diff --git a/DeepGraft/AttMIL_simple_no_viral.yaml b/DeepGraft/AttMIL_simple_no_viral.yaml
index ac18f9cea8c08d0cf2f82f06f7cec31c7c6cec34..4aeac661a56d260f945d7baaa40cb75c7e8600e1 100644
--- a/DeepGraft/AttMIL_simple_no_viral.yaml
+++ b/DeepGraft/AttMIL_simple_no_viral.yaml
@@ -11,12 +11,12 @@ General:
     frozen_bn: False
     patience: 50
     server: test #train #test
-    log_path: logs/
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
     fold: 1
     nfold: 4
diff --git a/DeepGraft/AttMIL_simple_norm_rej_rest.yaml b/DeepGraft/AttMIL_simple_norm_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d27c5e7712081664823804fb1ea4b4dc91def26
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_norm_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/limit_20_split_PAS_HE_Jones_norm_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: AttMIL
+    n_classes: 3
+    backbone: simple
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/AttMIL_simple_norm_rest.yaml b/DeepGraft/AttMIL_simple_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8dafb657ed479545690f970953f089a1d16a5502
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: simple
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0004
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/AttMIL_simple_rej_rest.yaml b/DeepGraft/AttMIL_simple_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b65afc6c5fb6e698eaf3bd6a5ab652aabaaee3ac
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: simple
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/AttMIL_simple_tcmr_viral.yaml b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
index 2aa4ed853c9a324ce4ce6c3ca938c91c637b1e97..b6221c1c9c88599f52e14f676d74e36d62986053 100644
--- a/DeepGraft/AttMIL_simple_tcmr_viral.yaml
+++ b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
@@ -11,7 +11,7 @@ General:
     frozen_bn: False
     patience: 20
     server: train #train #test
-    log_path: logs/
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
diff --git a/DeepGraft/Chowder_resnet50_bin_1.yaml b/DeepGraft/Chowder_resnet50_bin_1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d94e5781ce5a0f67409962c4d15e825d0787d333
--- /dev/null
+++ b/DeepGraft/Chowder_resnet50_bin_1.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/reduced_split_PAS_bin_1.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 16
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 16
+
+Model:
+    name: Chowder
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/Chowder_resnet50_tcmr_viral.yaml b/DeepGraft/Chowder_resnet50_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c1765328896a6793a5f68b11578d9b152b6a4509
--- /dev/null
+++ b/DeepGraft/Chowder_resnet50_tcmr_viral.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 32
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 32
+
+Model:
+    name: Chowder
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_feat_norm_rest.yaml b/DeepGraft/TransMIL_feat_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dab6568ee8f7895affa58eec70bacc0f8adccfef
--- /dev/null
+++ b/DeepGraft/TransMIL_feat_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: features
+    in_features: 512
+    out_features: 1024
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_resnet50_no_viral.yaml b/DeepGraft/TransMIL_resnet50_no_viral.yaml
index 2e3394a8342caab9824528b5ec254bc6c91b7ba3..b9989c2a5231ee5f5a65629b24d4a5f8ea784f0d 100644
--- a/DeepGraft/TransMIL_resnet50_no_viral.yaml
+++ b/DeepGraft/TransMIL_resnet50_no_viral.yaml
@@ -11,12 +11,12 @@ General:
     frozen_bn: False
     patience: 50
     server: test #train #test
-    log_path: logs/
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
     fold: 1
     nfold: 4
@@ -41,7 +41,7 @@ Optimizer:
     opt_eps: null 
     opt_betas: null
     momentum: null 
-    weight_decay: 0.00001
+    weight_decay: 0.01
 
 Loss:
     base_loss: CrossEntropyLoss
diff --git a/DeepGraft/TransMIL_resnet50_norm_rej_rest.yaml b/DeepGraft/TransMIL_resnet50_norm_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..98857434cac5d09f25c7e68e05e41abf0c16c325
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_norm_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/limit_20_split_PAS_HE_Jones_norm_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransMIL
+    n_classes: 3
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_resnet50_norm_rest.yaml b/DeepGraft/TransMIL_resnet50_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b58c6d880b5707fcdd2aa39e65ae2fb77871bfcd
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_resnet50_rej_rest.yaml b/DeepGraft/TransMIL_resnet50_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..35cfe86ae9c2fbe79b8c0919574139ed83c723dc
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_resnet50_rejections.yaml b/DeepGraft/TransMIL_resnet50_rejections.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e8e75561aace9ea467664e9ccfc6e2ceaee13039
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_rejections.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/reduced_split_PAS_HE_Jones_rejections.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 3
+    backbone: resnet50
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
index df205a2342386af0e48aedb777b8e5a5e6aa7aba..19b26f5cfe864fd03516e137c4b3c1d188b1d77e 100644
--- a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -16,7 +16,7 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 1
     nfold: 3
diff --git a/DeepGraft/TransMIL_retccl_norm_rej_rest.yaml b/DeepGraft/TransMIL_retccl_norm_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..52a78ad5fb8d0c2c8a467ed6c990378e052822d5
--- /dev/null
+++ b/DeepGraft/TransMIL_retccl_norm_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransMIL
+    n_classes: 3
+    backbone: retccl
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_retccl_norm_rest.yaml b/DeepGraft/TransMIL_retccl_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fa9988de796bf5e7f2d8ce418ccac95b6d28863e
--- /dev/null
+++ b/DeepGraft/TransMIL_retccl_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: retccl
+    in_features: 512
+    out_features: 1024
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_retccl_rej_rest.yaml b/DeepGraft/TransMIL_retccl_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9bf335ed309c92dbd02f73f41931f0499f22ae1d
--- /dev/null
+++ b/DeepGraft/TransMIL_retccl_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: retccl
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_abmr_tcmr.yaml b/DeepGraft/TransformerMIL_resnet50_abmr_tcmr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6cdfaff0b8aed8ad4c57a1081d1e3796885128c2
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_abmr_tcmr.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/limit_20_split_PAS_HE_Jones_abmr_tcmr.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_all.yaml b/DeepGraft/TransformerMIL_resnet50_all.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f08bd3a46e094e900d41e23d7aec88db4a75348
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_all.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/limit_20_split_PAS_HE_Jones_all.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 6
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_no_viral.yaml b/DeepGraft/TransformerMIL_resnet50_no_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..faccfa187ea084c2b0ceb58c6043888e9f670de6
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_no_viral.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/reduced_split_PAS_HE_Jones_no_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 4
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0004
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_norm_rej_rest.yaml b/DeepGraft/TransformerMIL_resnet50_norm_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a6f0c827a6cd2f4601f9927d4b346a6b5092834f
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_norm_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/limit_20_split_PAS_HE_Jones_norm_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransformerMIL
+    n_classes: 3
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_norm_rest.yaml b/DeepGraft/TransformerMIL_resnet50_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..755d10408fdb8f9d79394626847f5a2eb294bc30
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_rej_rest.yaml b/DeepGraft/TransformerMIL_resnet50_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f375b315b4c6b4dba8fc17e5d4ca4ae6763f3ce
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_rejections.yaml b/DeepGraft/TransformerMIL_resnet50_rejections.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87cf4e1dfc6328f631a5cacca6446f36a2bca4de
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_rejections.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/reduced_split_PAS_HE_Jones_rejections.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 3
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransformerMIL_resnet50_tcmr_viral.yaml
index 896bef13fb4cde9841f9afcb7e87a7fe6d72dd91..b5a6b68041116ea7726d5e023e769307e18a86bc 100644
--- a/DeepGraft/TransformerMIL_resnet50_tcmr_viral.yaml
+++ b/DeepGraft/TransformerMIL_resnet50_tcmr_viral.yaml
@@ -16,7 +16,7 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 1
     nfold: 3
@@ -24,11 +24,11 @@ Data:
 
     train_dataloader:
         batch_size: 1 
-        num_workers: 8
+        num_workers: 32
 
     test_dataloader:
         batch_size: 1
-        num_workers: 8
+        num_workers: 32
 
 Model:
     name: TransformerMIL
diff --git a/DeepGraft/TransformerMIL_resnet50_tcmr_viral_U.yaml b/DeepGraft/TransformerMIL_resnet50_tcmr_viral_U.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fab6e3207510ef8d7b72f6697a8bd2e282e6bf5f
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_tcmr_viral_U.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_Utrecht.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_val_RU.yaml b/DeepGraft/TransformerMIL_resnet50_val_RU.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..959cef9dde5b5e3f89af8471c93842c9101136ad
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_val_RU.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_val_RU.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_resnet50_viral_other.yaml b/DeepGraft/TransformerMIL_resnet50_viral_other.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..07ceedae98c1c9b2d28c6fa58fe0e2caf99fbba2
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_viral_other.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_HE_Jones_viral_other.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_retccl_bin_1.yaml b/DeepGraft/TransformerMIL_retccl_bin_1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..64bcc7b19fd2f5163614bb9371d3920b4ab983b0
--- /dev/null
+++ b/DeepGraft/TransformerMIL_retccl_bin_1.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /homeStor1/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/homeStor1/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/homeStor1/ylan/DeepGraft/training_tables/reduced_split_PAS_bin_1.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: retccl
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_retccl_no_viral.yaml b/DeepGraft/TransformerMIL_retccl_no_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5590f9d9449fda73fadb880022efe2c3b6751896
--- /dev/null
+++ b/DeepGraft/TransformerMIL_retccl_no_viral.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/reduced_split_PAS_no_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 4
+    backbone: retccl
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_retccl_norm_rej_rest.yaml b/DeepGraft/TransformerMIL_retccl_norm_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99c94ffe030eeb1be446ef7110b69838896e5a40
--- /dev/null
+++ b/DeepGraft/TransformerMIL_retccl_norm_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/limit_20_split_PAS_HE_Jones_norm_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransformerMIL
+    n_classes: 3
+    backbone: retccl
+    in_features: 512
+    out_features: 1024
+
+
+Optimizer:
+    opt: adamw
+    lr: 0.0004
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_retccl_norm_rest.yaml b/DeepGraft/TransformerMIL_retccl_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9c7341d5c0c901be73312bde35f15b1122f2bc78
--- /dev/null
+++ b/DeepGraft/TransformerMIL_retccl_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: retccl
+    in_features: 512
+    out_features: 1024
+
+
+Optimizer:
+    opt: adamw
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_retccl_rej_rest.yaml b/DeepGraft/TransformerMIL_retccl_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..20db4da8cfc22daf557025feaa12632e58d2f57c
--- /dev/null
+++ b/DeepGraft/TransformerMIL_retccl_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: retccl
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransformerMIL_retccl_tcmr_viral.yaml b/DeepGraft/TransformerMIL_retccl_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2cc80d1aca23f748acac825226323093ede5bbe3
--- /dev/null
+++ b/DeepGraft/TransformerMIL_retccl_tcmr_viral.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: retccl
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/code/MyLoss/__init__.py b/code/MyLoss/__init__.py
index bd772e346e2d91f749e52282bb72bcb1616b1eff..637208132c62d0ccd6a767840ce6c29cbf63f9f8 100755
--- a/code/MyLoss/__init__.py
+++ b/code/MyLoss/__init__.py
@@ -6,6 +6,7 @@ from .dice_loss import GDiceLoss, GDiceLossV2, SSLoss, SoftDiceLoss,\
      IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss, DC_and_CE_loss,\
          PenaltyGDiceLoss, DC_and_topk_loss, ExpLog_loss
 from .focal_loss import FocalLoss
+from .focal_loss_ori import FocalLoss_Ori
 from .hausdorff import HausdorffDTLoss, HausdorffERLoss
 from .lovasz_loss import LovaszSoftmax
 from .ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss,\
diff --git a/code/MyLoss/__pycache__/__init__.cpython-39.pyc b/code/MyLoss/__pycache__/__init__.cpython-39.pyc
index 119cc4d4a02cb0c100c171e5e23385ccb8309724..7da3e9984ce1b902f3b782c5d227bc06682e938a 100644
Binary files a/code/MyLoss/__pycache__/__init__.cpython-39.pyc and b/code/MyLoss/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/MyLoss/__pycache__/focal_loss.cpython-39.pyc b/code/MyLoss/__pycache__/focal_loss.cpython-39.pyc
index 4b2734f4edcae0885494fb2002e9678f127ccc8d..c6095b09fb0de1b97c0210ca024f36a0570d93c4 100644
Binary files a/code/MyLoss/__pycache__/focal_loss.cpython-39.pyc and b/code/MyLoss/__pycache__/focal_loss.cpython-39.pyc differ
diff --git a/code/MyLoss/__pycache__/focal_loss_ori.cpython-39.pyc b/code/MyLoss/__pycache__/focal_loss_ori.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a4cf4db95dd0e42102480ca738d4dd95633ddef
Binary files /dev/null and b/code/MyLoss/__pycache__/focal_loss_ori.cpython-39.pyc differ
diff --git a/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc
index 4a56099c87d6626e35a3f5d5f82502c95139ee5b..db4ebcceee99d56aa365237e82f0e543b7ba5cc1 100644
Binary files a/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc and b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ
diff --git a/code/MyLoss/focal_loss.py b/code/MyLoss/focal_loss.py
index 05171a9753f40c7116803809487ccabc1d9fdec5..2c3c2e22d3e83dd18657cabbc1ab6b0067d44540 100755
--- a/code/MyLoss/focal_loss.py
+++ b/code/MyLoss/focal_loss.py
@@ -90,4 +90,4 @@ class FocalLoss(nn.Module):
             loss = loss.sum()
         return loss
 
-    
+  
diff --git a/code/MyLoss/focal_loss_ori.py b/code/MyLoss/focal_loss_ori.py
new file mode 100644
index 0000000000000000000000000000000000000000..31ea24cf6260db8a0c10da69367b60cbf7f24d98
--- /dev/null
+++ b/code/MyLoss/focal_loss_ori.py
@@ -0,0 +1,86 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FocalLoss_Ori(nn.Module):
+    """
+    This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
+    'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
+    Focal_Loss= -1*alpha*((1-pt)**gamma)*log(pt)
+    Args:
+        num_class: number of classes
+        alpha: class balance factor
+        gamma:
+        ignore_index:
+        reduction:
+    """
+
+    def __init__(self, num_class, alpha=None, gamma=2, ignore_index=None, reduction='mean'):
+        super(FocalLoss_Ori, self).__init__()
+        self.num_class = num_class
+        self.gamma = gamma
+        self.reduction = reduction
+        self.smooth = 1e-4
+        self.ignore_index = ignore_index
+        self.alpha = alpha
+        if alpha is None:
+            self.alpha = torch.ones(num_class, )
+        elif isinstance(alpha, (int, float)):
+            self.alpha = torch.as_tensor([alpha] * num_class)
+        elif isinstance(alpha, (list, np.ndarray)):
+            self.alpha = torch.as_tensor(alpha)
+        if self.alpha.shape[0] != num_class:
+            raise RuntimeError('the length not equal to number of class')
+
+        # if isinstance(self.alpha, (list, tuple, np.ndarray)):
+        #     assert len(self.alpha) == self.num_class
+        #     self.alpha = torch.Tensor(list(self.alpha))
+        # elif isinstance(self.alpha, (float, int)):
+        #     assert 0 < self.alpha < 1.0, 'alpha should be in `(0,1)`)'
+        #     assert balance_index > -1
+        #     alpha = torch.ones((self.num_class))
+        #     alpha *= 1 - self.alpha
+        #     alpha[balance_index] = self.alpha
+        #     self.alpha = alpha
+        # elif isinstance(self.alpha, torch.Tensor):
+        #     self.alpha = self.alpha
+        # else:
+        #     raise TypeError('Not support alpha type, expect `int|float|list|tuple|torch.Tensor`')
+
+    def forward(self, logit, target):
+        # assert isinstance(self.alpha,torch.Tensor)\
+        N, C = logit.shape[:2]
+        alpha = self.alpha.to(logit.device)
+        prob = F.softmax(logit, dim=1)
+        if prob.dim() > 2:
+            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
+            prob = prob.view(N, C, -1)
+            prob = prob.transpose(1, 2).contiguous()  # [N,C,d1*d2..] -> [N,d1*d2..,C]
+            prob = prob.view(-1, prob.size(-1))  # [N,d1*d2..,C]-> [N*d1*d2..,C]
+        ori_shp = target.shape
+        target = target.view(-1, 1)  # [N,d1,d2,...]->[N*d1*d2*...,1]
+        valid_mask = None
+        if self.ignore_index is not None:
+            valid_mask = target != self.ignore_index
+            target = target * valid_mask
+
+        # ----------memory saving way--------
+        
+        prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nan
+        logpt = torch.log(prob)
+        # alpha_class = alpha.gather(0, target.view(-1))
+        alpha_class = alpha[target.squeeze().long()]
+        class_weight = -alpha_class * torch.pow(torch.sub(1.0, prob), self.gamma)
+        loss = class_weight * logpt
+        if valid_mask is not None:
+            loss = loss * valid_mask.squeeze()
+
+        if self.reduction == 'mean':
+            loss = loss.mean()
+            if valid_mask is not None:
+                loss = loss.sum() / valid_mask.sum()
+        elif self.reduction == 'none':
+            loss = loss.view(ori_shp)
+        return loss
+
diff --git a/code/MyLoss/loss_factory.py b/code/MyLoss/loss_factory.py
index 8ff6d1814af1bb43a9e8e16a243f6fe87fe476cf..61502bf66fda9c2d2f4ebb8df1566a1d56583b5e 100755
--- a/code/MyLoss/loss_factory.py
+++ b/code/MyLoss/loss_factory.py
@@ -9,6 +9,7 @@ from .dice_loss import GDiceLoss, GDiceLossV2, SSLoss, SoftDiceLoss,\
      IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss, DC_and_CE_loss,\
          PenaltyGDiceLoss, DC_and_topk_loss, ExpLog_loss
 from .focal_loss import FocalLoss
+from .focal_loss_ori import FocalLoss_Ori
 from .hausdorff import HausdorffDTLoss, HausdorffERLoss
 from .lovasz_loss import LovaszSoftmax
 from .ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss,\
@@ -17,20 +18,22 @@ from .poly_loss import PolyLoss
 
 from pytorch_toolbelt import losses as L
 
-def create_loss(args, w1=1.0, w2=0.5):
+def create_loss(args, n_classes, w1=1.0, w2=0.5):
     conf_loss = args.base_loss
-    if args.loss_weight: 
-        weight = torch.tensor(args.loss_weight)
-    else: weight = None
+    # n_classes = args.model.n_classes
+    # if args.loss_weight: 
+    #     weight = torch.tensor(args.loss_weight)
+    # else: weight = None
     ### MulticlassJaccardLoss(classes=np.arange(11)
     # mode = args.base_loss #BINARY_MODE \MULTICLASS_MODE \MULTILABEL_MODE 
     loss = None
+    print(conf_loss)
     if hasattr(nn, conf_loss): 
-        loss = getattr(nn, conf_loss)(weight=weight, label_smoothing=0.5) 
+        loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
         # loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
     #binary loss
     elif conf_loss == "focal":
-        loss = L.BinaryFocalLoss()
+        loss = FocalLoss_Ori(n_classes)
     elif conf_loss == "jaccard":
         loss = L.BinaryJaccardLoss()
     elif conf_loss == "jaccard_log":
diff --git a/code/__pycache__/test_visualize.cpython-39.pyc b/code/__pycache__/test_visualize.cpython-39.pyc
index f22258b82070b835ec1875cf4e46622cbad9198a..e961ce76be403bfa75de922eeedcbaa056c2b97c 100644
Binary files a/code/__pycache__/test_visualize.cpython-39.pyc and b/code/__pycache__/test_visualize.cpython-39.pyc differ
diff --git a/code/models/resnet50.py b/code/datasets/ResNet.py
similarity index 63%
rename from code/models/resnet50.py
rename to code/datasets/ResNet.py
index 89e23d7460df4b37d9455d9ab7a07d4350fd802f..f8fe70648f55b58eacc5da0578821d7842c563f3 100644
--- a/code/models/resnet50.py
+++ b/code/datasets/ResNet.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
-from .utils import load_state_dict_from_url
+from torch.hub import load_state_dict_from_url
+import torch.nn.functional as F
+from torch.nn import Parameter
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -73,27 +75,21 @@ class BasicBlock(nn.Module):
 
 
 class Bottleneck(nn.Module):
-    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
-    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
-    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
-    # This variant is also known as ResNet V1.5 and improves accuracy according to
-    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
-
     expansion = 4
 
     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
-                 base_width=64, dilation=1, norm_layer=None):
+                 base_width=64, dilation=1, norm_layer=None, momentum_bn=0.1):
         super(Bottleneck, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         width = int(planes * (base_width / 64.)) * groups
         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
         self.conv1 = conv1x1(inplanes, width)
-        self.bn1 = norm_layer(width)
+        self.bn1 = norm_layer(width, momentum=momentum_bn)
         self.conv2 = conv3x3(width, width, stride, groups, dilation)
-        self.bn2 = norm_layer(width)
+        self.bn2 = norm_layer(width, momentum=momentum_bn)
         self.conv3 = conv1x1(width, planes * self.expansion)
-        self.bn3 = norm_layer(planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion, momentum=momentum_bn)
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -120,12 +116,23 @@ class Bottleneck(nn.Module):
 
         return out
 
+class NormedLinear(nn.Module):
+
+    def __init__(self, in_features, out_features):
+        super(NormedLinear, self).__init__()
+        self.weight = Parameter(torch.Tensor(in_features, out_features))
+        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
+
+    def forward(self, x):
+        out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
+        return out
 
 class ResNet(nn.Module):
 
     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
-                 norm_layer=None):
+                 norm_layer=None, two_branch=False, mlp=False, normlinear=False,
+                 momentum_bn=0.1, attention=False, attention_layers=3, return_attn=False):
         super(ResNet, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -133,6 +140,7 @@ class ResNet(nn.Module):
 
         self.inplanes = 64
         self.dilation = 1
+        self.return_attn = return_attn
         if replace_stride_with_dilation is None:
             # each element in the tuple indicates if we should replace
             # the 2x2 stride with a dilated convolution instead
@@ -142,9 +150,14 @@ class ResNet(nn.Module):
                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
         self.groups = groups
         self.base_width = width_per_group
+        self.two_branch = two_branch
+        self.momentum_bn = momentum_bn
+        self.mlp = mlp
+        linear = NormedLinear if normlinear else nn.Linear
+
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                bias=False)
-        self.bn1 = norm_layer(self.inplanes)
+        self.bn1 = norm_layer(self.inplanes, momentum=momentum_bn)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, 64, layers[0])
@@ -154,8 +167,33 @@ class ResNet(nn.Module):
                                        dilate=replace_stride_with_dilation[1])
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                        dilate=replace_stride_with_dilation[2])
+
+        if attention:
+            self.att_branch = self._make_layer(block, 512, attention_layers, 1, attention=True)
+        else:
+            self.att_branch = None
+
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        if self.mlp:
+            if self.two_branch:
+                self.fc = nn.Sequential(
+                    nn.Linear(512 * block.expansion, 512 * block.expansion),
+                    nn.ReLU()
+                ) 
+                self.instDis = linear(512 * block.expansion, num_classes)
+                self.groupDis = linear(512 * block.expansion, num_classes)
+            else:
+                self.fc = nn.Sequential(
+                    nn.Linear(512 * block.expansion, 512 * block.expansion),
+                    nn.ReLU(),
+                    linear(512 * block.expansion, num_classes)
+                ) 
+        else:
+            self.fc = nn.Linear(512 * block.expansion, num_classes)
+            if self.two_branch:
+                self.groupDis = nn.Linear(512 * block.expansion, num_classes)
+
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -174,7 +212,7 @@ class ResNet(nn.Module):
                 elif isinstance(m, BasicBlock):
                     nn.init.constant_(m.bn2.weight, 0)
 
-    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, attention=False):
         norm_layer = self._norm_layer
         downsample = None
         previous_dilation = self.dilation
@@ -184,22 +222,31 @@ class ResNet(nn.Module):
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
                 conv1x1(self.inplanes, planes * block.expansion, stride),
-                norm_layer(planes * block.expansion),
+                norm_layer(planes * block.expansion, momentum=self.momentum_bn),
             )
 
         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
-                            self.base_width, previous_dilation, norm_layer))
+                            self.base_width, previous_dilation, norm_layer, momentum_bn=self.momentum_bn))
         self.inplanes = planes * block.expansion
         for _ in range(1, blocks):
             layers.append(block(self.inplanes, planes, groups=self.groups,
                                 base_width=self.base_width, dilation=self.dilation,
-                                norm_layer=norm_layer))
+                                norm_layer=norm_layer, momentum_bn=self.momentum_bn))
+
+        if attention:
+            layers.append(nn.Sequential(
+                conv1x1(self.inplanes, 128),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                conv1x1(128, 1),
+                nn.BatchNorm2d(1),
+                nn.Sigmoid()
+            ))
 
         return nn.Sequential(*layers)
 
-    def _forward_impl(self, x):
-        # See note [TorchScript super()]
+    def forward(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
@@ -209,15 +256,23 @@ class ResNet(nn.Module):
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
+        if self.att_branch is not None:
+            att_map = self.att_branch(x)
+            x = x + att_map * x
 
         x = self.avgpool(x)
         x = torch.flatten(x, 1)
-        x = self.fc(x)
-
-        return x
-
-    def forward(self, x):
-        return self._forward_impl(x)
+        if self.mlp and self.two_branch:
+            x = self.fc(x)
+            x1 = self.instDis(x)
+            x2 = self.groupDis(x)
+            return [x1, x2]
+        else:
+            x1 = self.fc(x)
+            if self.two_branch:
+                x2 = self.groupDis(x)
+                return [x1, x2]
+            return x1
 
 
 def _resnet(arch, block, layers, pretrained, progress, **kwargs):
@@ -232,7 +287,6 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
 def resnet18(pretrained=False, progress=True, **kwargs):
     r"""ResNet-18 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
-
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
@@ -241,11 +295,9 @@ def resnet18(pretrained=False, progress=True, **kwargs):
                    **kwargs)
 
 
-
 def resnet34(pretrained=False, progress=True, **kwargs):
     r"""ResNet-34 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
-
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
@@ -254,11 +306,9 @@ def resnet34(pretrained=False, progress=True, **kwargs):
                    **kwargs)
 
 
-
 def resnet50(pretrained=False, progress=True, **kwargs):
     r"""ResNet-50 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
-
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
@@ -267,11 +317,9 @@ def resnet50(pretrained=False, progress=True, **kwargs):
                    **kwargs)
 
 
-
 def resnet101(pretrained=False, progress=True, **kwargs):
     r"""ResNet-101 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
-
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
@@ -280,14 +328,70 @@ def resnet101(pretrained=False, progress=True, **kwargs):
                    **kwargs)
 
 
-
 def resnet152(pretrained=False, progress=True, **kwargs):
     r"""ResNet-152 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
-
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
     """
     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
-                   **kwargs)
\ No newline at end of file
+                   **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-101 32x8d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-50-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-101-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
\ No newline at end of file
diff --git a/code/datasets/__init__.py b/code/datasets/__init__.py
index 0eb1fe74d53130bd211a4475d70a12507e9cf879..2989858e6e652de44eaa34b0eb5ba798f1ffefdf 100644
--- a/code/datasets/__init__.py
+++ b/code/datasets/__init__.py
@@ -1,3 +1,4 @@
 
 from .custom_jpg_dataloader import JPGMILDataloader
 from .data_interface import MILDataModule
+from .fast_tensor_dl import FastTensorDataLoader
diff --git a/code/datasets/__pycache__/ResNet.cpython-39.pyc b/code/datasets/__pycache__/ResNet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d963e71935c049d798681b7b63ea581b34f4472
Binary files /dev/null and b/code/datasets/__pycache__/ResNet.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/__init__.cpython-39.pyc b/code/datasets/__pycache__/__init__.cpython-39.pyc
index 3735aa6b4ec5cd68471afdec6d2068bc6293a2ed..d67531a6e6443d08e896d436a6b2ba379fe43528 100644
Binary files a/code/datasets/__pycache__/__init__.cpython-39.pyc and b/code/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/code/datasets/__pycache__/custom_dataloader.cpython-39.pyc
index 99e793aa4b84b2a5e9ee1882437346e6e4a33002..6d666116a162ee7799540dff32ce3b4b03829433 100644
Binary files a/code/datasets/__pycache__/custom_dataloader.cpython-39.pyc and b/code/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
index 1cd6ef3a6a996cd23622c3e29f8ed89d4d0d9038..9f536c5dcea366ca8fa4c041f92b33ca5546e2a4 100644
Binary files a/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc and b/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/custom_resnet50.cpython-39.pyc b/code/datasets/__pycache__/custom_resnet50.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e83a6cd9d528d97ded3c4ae4e451d5cad70e63ed
Binary files /dev/null and b/code/datasets/__pycache__/custom_resnet50.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/data_interface.cpython-39.pyc b/code/datasets/__pycache__/data_interface.cpython-39.pyc
index 59584d31b16ee5ba7d08220a2174fb5e79e9aaed..2e97390e89b0ccb4905df819317f14850af46697 100644
Binary files a/code/datasets/__pycache__/data_interface.cpython-39.pyc and b/code/datasets/__pycache__/data_interface.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/fast_tensor_dl.cpython-39.pyc b/code/datasets/__pycache__/fast_tensor_dl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b40eb162a0900b3bf421252bad862d8b72d4916
Binary files /dev/null and b/code/datasets/__pycache__/fast_tensor_dl.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/zarr_feature_dataloader.cpython-39.pyc b/code/datasets/__pycache__/zarr_feature_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6316a91dcd30db8d68e2af702b444f46463713d2
Binary files /dev/null and b/code/datasets/__pycache__/zarr_feature_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/custom_dataloader.py b/code/datasets/custom_dataloader.py
index cf65534d749f1cdc764bbdd8193245043c7853f6..f66246188040f435a184fae2cff904c12c2b7c90 100644
--- a/code/datasets/custom_dataloader.py
+++ b/code/datasets/custom_dataloader.py
@@ -18,7 +18,7 @@ import pandas as pd
 import json
 import albumentations as A
 from albumentations.pytorch import ToTensorV2
-from transformers import AutoFeatureExtractor
+# from transformers import AutoFeatureExtractor
 from imgaug import augmenters as iaa
 import imgaug as ia
 from torchsampler import ImbalancedDatasetSampler
diff --git a/code/datasets/custom_jpg_dataloader.py b/code/datasets/custom_jpg_dataloader.py
index b47fc7fbe2b461a5e446106f3d8c5a9edb55ba4a..c28acc1a378e95ad770a045e583f6ee0011e2181 100644
--- a/code/datasets/custom_jpg_dataloader.py
+++ b/code/datasets/custom_jpg_dataloader.py
@@ -3,30 +3,45 @@ ToDo: remove bag_size
 '''
 
 
+# from custom_resnet50 import resnet50_baseline
 import numpy as np
 from pathlib import Path
 import torch
 from torch.utils import data
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import random_split, DataLoader
 from tqdm import tqdm
 import torchvision.transforms as transforms
 from PIL import Image
 import cv2
 import json
 import albumentations as A
+from albumentations.pytorch import ToTensorV2
 from imgaug import augmenters as iaa
 import imgaug as ia
 from torchsampler import ImbalancedDatasetSampler
 
 
+
 class RangeNormalization(object):
     def __call__(self, sample):
-        img = sample
-        return (img / 255.0 - 0.5) / 0.5
+
+        # Maps a uint8-range tensor from [0, 255] to float32 [-1, 1].
+        # MEAN/STD below are computed but unused — the ImageNet mean/std
+        # normalization is left commented out; TODO confirm which scheme is intended.
+        MEAN = 255 * torch.tensor([0.485, 0.456, 0.406])
+        STD = 255 * torch.tensor([0.229, 0.224, 0.225])
+
+        # x = torch.from_numpy(sample)
+        x = sample.type(torch.float32)
+        # x = x.permute(-1, 0, 1)
+        x = 2 * x / 255 - 1
+
+        # x = (x - MEAN[:, None, None])/ STD[:, None, None]
+
+        # 
+        
+        return x
 
 class JPGMILDataloader(data.Dataset):
     
-    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, max_bag_size=1296):
+    def __init__(self, file_path, label_path, mode, n_classes, cache=False, data_cache_size=100, max_bag_size=1000):
         super().__init__()
 
         self.data_info = []
@@ -42,15 +57,16 @@ class JPGMILDataloader(data.Dataset):
         self.max_bag_size = max_bag_size
         self.min_bag_size = 120
         self.empty_slides = []
+        self.corrupt_slides = []
         # self.label_file = label_path
         recursive = True
         
         # read labels and slide_path from csv
         with open(self.label_path, 'r') as f:
             temp_slide_label_dict = json.load(f)[mode]
+            print(len(temp_slide_label_dict))
             for (x, y) in temp_slide_label_dict:
                 x = Path(x).stem 
-
                 # x_complete_path = Path(self.file_path)/Path(x)
                 for cohort in Path(self.file_path).iterdir():
                     x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
@@ -60,12 +76,20 @@ class JPGMILDataloader(data.Dataset):
                             self.slideLabelDict[x] = y
                             self.files.append(x_complete_path)
                         else: self.empty_slides.append(x_complete_path)
-        # print(len(self.empty_slides))
-        # print(self.empty_slides)
+                    
 
+        print(f'Slides with bag size under {self.min_bag_size}: ', len(self.empty_slides))
+        # print(self.empty_slides)
+        # print(len(self.files))
+        # print(len(self.corrupt_slides))
+        # print(self.corrupt_slides)
+        home = Path.cwd().parts[1]
+        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
+        with open(self.slide_patient_dict_path, 'r') as f:
+            slide_patient_dict = json.load(f)
 
         for slide_dir in tqdm(self.files):
-            self._add_data_infos(str(slide_dir.resolve()), load_data)
+            self._add_data_infos(str(slide_dir.resolve()), cache, slide_patient_dict)
 
 
         self.resize_transforms = A.Compose([
@@ -90,7 +114,24 @@ class JPGMILDataloader(data.Dataset):
             ], name="MyOneOf")
 
         ], name="MyAug")
-
+        self.albu_transforms = A.Compose([
+            A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30, always_apply=False, p=0.5),
+            A.ColorJitter(always_apply=False, p=0.5),
+            A.RandomGamma(gamma_limit=(80,120)),
+            A.Flip(p=0.5),
+            A.RandomRotate90(p=0.5),
+            # A.OneOf([
+            #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+            #     A.Affine(
+            #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+            #         rotate=(-45, 45),
+            #         shear=(-4, 4),
+            #         cval=8,
+            #         )
+            # ]),
+            A.Normalize(),
+            ToTensorV2(),
+        ])
         # self.train_transforms = A.Compose([
         #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
         #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
@@ -111,11 +152,13 @@ class JPGMILDataloader(data.Dataset):
         #     ToTensorV2(),
         # ])
         self.val_transforms = transforms.Compose([
-            # A.Normalize(),
-            # ToTensorV2(),
-            RangeNormalization(),
+            # 
             transforms.ToTensor(),
-
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225],
+            ),
+            # RangeNormalization(),
         ])
         self.img_transforms = transforms.Compose([    
             transforms.RandomHorizontalFlip(p=1),
@@ -130,7 +173,7 @@ class JPGMILDataloader(data.Dataset):
 
     def __getitem__(self, index):
         # get data
-        (batch, batch_names), label, name = self.get_data(index)
+        (batch, batch_names), label, name, patient = self.get_data(index)
         out_batch = []
         seq_img_d = self.train_transforms.to_deterministic()
         
@@ -139,10 +182,15 @@ class JPGMILDataloader(data.Dataset):
             # print(.shape)
             for img in batch: # expects numpy 
                 img = img.numpy().astype(np.uint8)
+                # img = self.albu_transforms(image=img)
+                # print(img)
                 # print(img.shape)
                 img = seq_img_d.augment_image(img)
-                img = self.val_transforms(img)
+                img = self.val_transforms(img.copy())
+                # print(img)
                 out_batch.append(img)
+                # img = self.albu_transforms(image=img)
+                # out_batch.append(img['image'])
 
         else:
             for img in batch:
@@ -154,6 +202,7 @@ class JPGMILDataloader(data.Dataset):
         #     # print(name)
         #     out_batch = torch.randn(self.bag_size,3,256,256)
         # else: 
+        # print(len(out_batch))
         out_batch = torch.stack(out_batch)
         # print(out_batch.shape)
         # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
@@ -165,19 +214,21 @@ class JPGMILDataloader(data.Dataset):
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
         # print(out_batch)
-        return out_batch, label, (name, batch_names) #, name_batch
+        return out_batch, label, (name, batch_names, patient) #, name_batch
 
     def __len__(self):
         return len(self.data_info)
     
-    def _add_data_infos(self, file_path, load_data):
+    def _add_data_infos(self, file_path, cache, slide_patient_dict):
+        """Register one slide (path, label, patient, empty cache slot) in data_info.
+
+        Slides whose stem is not in slideLabelDict are silently skipped.
+        The 'cache' flag is currently unused here — TODO confirm intent.
+        """
+
+        
         wsi_name = Path(file_path).stem
         if wsi_name in self.slideLabelDict:
             # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
             label = self.slideLabelDict[wsi_name]
-            # print(wsi_name)
+            patient = slide_patient_dict[wsi_name]
             idx = -1
-            self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
+            self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'patient': patient,'cache_idx': idx})
 
     def _load_data(self, file_path):
         """Load data to the cache given the file
@@ -186,20 +237,13 @@ class JPGMILDataloader(data.Dataset):
         """
         wsi_batch = []
         name_batch = []
-        # print(wsi_batch)
-        # for tile_path in Path(file_path).iterdir():
-        #     print(tile_path)
         for tile_path in Path(file_path).iterdir():
-            # print(tile_path)
             img = np.asarray(Image.open(tile_path)).astype(np.uint8)
             img = torch.from_numpy(img)
-
-            # print(wsi_batch)
             wsi_batch.append(img)
             
             name_batch.append(tile_path.stem)
                 
-        # if wsi_batch:
         wsi_batch = torch.stack(wsi_batch)
         if len(wsi_batch.shape) < 4: 
             wsi_batch.unsqueeze(0)
@@ -213,6 +257,7 @@ class JPGMILDataloader(data.Dataset):
         #     print(wsi_batch.shape)
         if wsi_batch.size(0) > self.max_bag_size:
             wsi_batch, name_batch, _ = to_fixed_size_bag(wsi_batch, name_batch, self.max_bag_size)
+        wsi_batch, name_batch = self.data_dropout(wsi_batch, name_batch, drop_rate=0.1)
         idx = self._add_to_cache((wsi_batch,name_batch), file_path)
         file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
         self.data_info[file_idx + idx]['cache_idx'] = idx
@@ -225,7 +270,7 @@ class JPGMILDataloader(data.Dataset):
             self.data_cache.pop(removal_keys[0])
             # remove invalid cache_idx
             # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
-            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'patient':di['patient'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
 
     def _add_to_cache(self, data, data_path):
         """Adds data to the cache and returns its index. There is one cache
@@ -269,9 +314,18 @@ class JPGMILDataloader(data.Dataset):
         cache_idx = self.data_info[i]['cache_idx']
         label = self.data_info[i]['label']
         name = self.data_info[i]['name']
+        patient = self.data_info[i]['patient']
+        
         # print(self.data_cache[fp][cache_idx])
-        return self.data_cache[fp][cache_idx], label, name
+        return self.data_cache[fp][cache_idx], label, name, patient
 
+    def data_dropout(self, bag, batch_names, drop_rate):
+        """Randomly keep a (1 - drop_rate) fraction of tiles from a bag.
+
+        Returns the sub-sampled bag tensor and the matching tile names,
+        both in the same (shuffled) order produced by torch.randperm.
+        """
+        bag_size = bag.shape[0]
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_samples = bag[bag_idxs]
+        name_samples = [batch_names[i] for i in bag_idxs]
+
+        return bag_samples, name_samples
 
 
 class RandomHueSaturationValue(object):
@@ -318,7 +372,7 @@ def to_fixed_size_bag(bag, names, bag_size: int = 512):
 
     # zero-pad if we don't have enough samples
     # zero_padded = torch.cat((bag_samples,
-                            # torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+    #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
 
     return bag_samples, name_samples, min(bag_size, len(bag))
 
@@ -355,28 +409,45 @@ class RandomHueSaturationValue(object):
 if __name__ == '__main__':
     from pathlib import Path
     import os
+    import time
+    from fast_tensor_dl import FastTensorDataLoader
+    from custom_resnet50 import resnet50_baseline
+    
+    
 
     home = Path.cwd().parts[1]
     train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
-    data_root = f'/{home}/ylan/data/DeepGraft/224_128um'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
     # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
+    # label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
     output_dir = f'/{data_root}/debug/augments'
     os.makedirs(output_dir, exist_ok=True)
 
     n_classes = 2
 
-    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', cache=False, n_classes=n_classes)
+    a = int(len(dataset)* 0.8)
+    b = int(len(dataset) - a)
+    train_data, valid_data = random_split(dataset, [a, b])
     # print(dataset.dataset)
     # a = int(len(dataset)* 0.8)
     # b = int(len(dataset) - a)
     # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
-    dl = DataLoader(dataset, None, num_workers=1, shuffle=False)
+    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
+    dl = DataLoader(train_data, batch_size=1, num_workers=16, sampler=ImbalancedDatasetSampler(train_data), pin_memory=True)
     # print(len(dl))
     # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    scaler = torch.cuda.amp.GradScaler()
 
+    model_ft = resnet50_baseline(pretrained=True)
+    for param in model_ft.parameters():
+        param.requires_grad = False
+    model_ft.to(device)
     
+
     
     # data = DataLoader(dataset, batch_size=1)
 
@@ -385,12 +456,27 @@ if __name__ == '__main__':
     #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
     c = 0
     label_count = [0] *n_classes
-    print(len(dl))
-    for item in dl: 
-        if c >= 5:
-            break
-        bag, label, (name, _) = item
-        label_count[torch.argmax(label)] += 1
+    # print(len(dl))
+    start = time.time()
+    for item in tqdm(dl): 
+
+        # if c >= 10:
+        #     break
+        bag, label, (name, batch_names, patient) = item
+        print(bag.shape)
+        print(len(batch_names))
+        bag = bag.squeeze(0).float().to(device)
+        label = label.to(device)
+        with torch.cuda.amp.autocast():
+            output = model_ft(bag)
+        c += 1
+    end = time.time()
+
+    print('Bag Time: ', end-start)
+        # print(label)
+        # print(name)
+        # print(patient)
+    #     label_count[torch.argmax(label)] += 1
         # print(name)
         # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
         
@@ -414,23 +500,23 @@ if __name__ == '__main__':
     #     # bag = item[0]
     #     bag = bag.squeeze()
     #     original = original.squeeze()
-        output_path = Path(output_dir) / name
-        output_path.mkdir(exist_ok=True)
-        for i in range(bag.shape[0]):
-            img = bag[i, :, :, :]
-            img = img.squeeze()
+        # output_path = Path(output_dir) / name
+        # output_path.mkdir(exist_ok=True)
+        # for i in range(bag.shape[0]):
+        #     img = bag[i, :, :, :]
+        #     img = img.squeeze()
             
-            img = ((img-img.min())/(img.max() - img.min())) * 255
-            # print(img)
-            # print(img)
-            img = img.numpy().astype(np.uint8).transpose(1,2,0)
+        #     img = ((img-img.min())/(img.max() - img.min())) * 255
+        #     # print(img)
+        #     # print(img)
+        #     img = img.numpy().astype(np.uint8).transpose(1,2,0)
 
             
-            img = Image.fromarray(img)
-            img = img.convert('RGB')
-            img.save(f'{output_path}/{i}.png')
+        #     img = Image.fromarray(img)
+        #     img = img.convert('RGB')
+        #     img.save(f'{output_path}/{i}.png')
 
-        c += 1
+        # c += 1
             
     #         o_img = original[i,:,:,:]
     #         o_img = o_img.squeeze()
@@ -449,4 +535,4 @@ if __name__ == '__main__':
     # b = torch.stack(a)
     # print(b)
     # c = to_fixed_size_bag(b, 512)
-    # print(c)
\ No newline at end of file
+    # print(c)
diff --git a/code/datasets/custom_npy_dataloader.py b/code/datasets/custom_npy_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2c99e7a08f796dd210e2d967317681a07ae16c5
--- /dev/null
+++ b/code/datasets/custom_npy_dataloader.py
@@ -0,0 +1,526 @@
+'''
+ToDo: remove bag_size
+'''
+
+
+from custom_resnet50 import resnet50_baseline
+import numpy as np
+from pathlib import Path
+import torch
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from tqdm import tqdm
+import torchvision.transforms as transforms
+from PIL import Image
+import cv2
+import json
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from imgaug import augmenters as iaa
+import imgaug as ia
+from torchsampler import ImbalancedDatasetSampler
+
+
+
+class RangeNormalization(object):
+    # Rescales pixel values from [0, 255] to [-1, 1].
+    def __call__(self, sample):
+        img = sample
+        return (img / 255.0 - 0.5) / 0.5
+
+class NPYMILDataloader(data.Dataset):
+    
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=500, max_bag_size=1000):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.min_bag_size = 120
+        self.empty_slides = []
+        self.corrupt_slides = []
+        # self.label_file = label_path
+        recursive = True
+        exclude_cohorts = ["debug", "test"]
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            print(len(temp_slide_label_dict))
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem 
+                # x_complete_path = Path(self.file_path)/Path(x)
+                for cohort in Path(self.file_path).iterdir():
+                    if cohort.stem not in exclude_cohorts:
+                        x_complete_path = cohort / 'BLOCKS'
+                        for slide_dir in Path(x_complete_path).iterdir():
+                            if slide_dir.suffix == '.npy':
+                                
+                                if slide_dir.stem == x:
+                            # if len(list(x_complete_path.iterdir())) > self.min_bag_size:
+                            # print(x_complete_path)
+                                    self.slideLabelDict[x] = y
+                                    self.files.append(slide_dir)
+                            # else: self.empty_slides.append(x_complete_path)
+                    
+        print(len(self.files))
+
+        print(f'Slides with bag size under {self.min_bag_size}: ', len(self.empty_slides))
+        # print(self.empty_slides)
+        # print(len(self.files))
+        # print(len(self.corrupt_slides))
+        # print(self.corrupt_slides)
+        home = Path.cwd().parts[1]
+        slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
+        with open(slide_patient_dict_path, 'r') as f:
+            slide_patient_dict = json.load(f)
+
+        for slide_dir in tqdm(self.files):
+            self._add_data_infos(str(slide_dir.resolve()), load_data, slide_patient_dict)
+
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
+        ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            # iaa.OneOf([
+            #     sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+            #     sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+            #     sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            # ], name="MyOneOf")
+
+        ], name="MyAug")
+        self.albu_transforms = A.Compose([
+            A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30, always_apply=False, p=0.5),
+            A.ColorJitter(always_apply=False, p=0.5),
+            A.RandomGamma(gamma_limit=(80,120)),
+            A.Flip(p=0.5),
+            A.RandomRotate90(p=0.5),
+            # A.OneOf([
+            #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+            #     A.Affine(
+            #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+            #         rotate=(-45, 45),
+            #         shear=(-4, 4),
+            #         cval=8,
+            #         )
+            # ]),
+            A.Normalize(),
+            ToTensorV2(),
+        ])
+        # self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+        #     # A.RandomGamma(),
+        #     # A.HorizontalFlip(),
+        #     # A.VerticalFlip(),
+        #     # A.RandomRotate90(),
+        #     # A.OneOf([
+        #     #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #     #     A.Affine(
+        #     #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+        #     #         rotate=(-45, 45),
+        #     #         shear=(-4, 4),
+        #     #         cval=8,
+        #     #         )
+        #     # ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+
+        ])
+        self.img_transforms = transforms.Compose([    
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandomVerticalFlip(p=1),
+            # histoTransforms.AutoRandomRotation(),
+            transforms.Lambda(lambda a: np.array(a)),
+        ]) 
+        self.hsv_transforms = transforms.Compose([
+            RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
+            transforms.ToTensor()
+        ])
+
+    def __getitem__(self, index):
+        # get data
+        (batch, batch_names), label, name, patient = self.get_data(index)
+        out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
+        
+        if self.mode == 'train':
+            # print(img)
+            # print(.shape)
+            for img in batch: # expects numpy 
+                # img = img.numpy().astype(np.uint8)
+                # img = self.albu_transforms(image=img)
+                # print(img)
+                # print(img.shape)
+                # print(img)
+                img = img.astype(np.uint8)
+
+                img = seq_img_d.augment_image(img)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+                # img = self.albu_transforms(image=img)
+                # out_batch.append(img['image'])
+
+        else:
+            for img in batch:
+                img = img.numpy().astype(np.uint8)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+
+        # if len(out_batch) == 0:
+        #     # print(name)
+        #     out_batch = torch.randn(self.bag_size,3,256,256)
+        # else: 
+        # print(len(out_batch))
+        out_batch = torch.stack(out_batch)
+        # print(out_batch.shape)
+        # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+        # print(out_batch.shape)
+        # if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]):
+        #     print(name)
+        #     print(out_batch.shape)
+        # out_batch = torch.permute(out_batch, (0, 2,1,3))
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        # print(out_batch)
+        return out_batch, label, (name, batch_names, patient) #, name_batch
+
+    def __len__(self):
+        return len(self.data_info)
+    
+    def _add_data_infos(self, file_path, load_data, slide_patient_dict):
+
+        
+        wsi_name = Path(file_path).stem
+        if wsi_name in self.slideLabelDict:
+            # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
+            label = self.slideLabelDict[wsi_name]
+            patient = slide_patient_dict[wsi_name]
+            idx = -1
+            self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'patient': patient,'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        wsi_batch = []
+        name_batch = []
+        # print(wsi_batch)
+        # for tile_path in Path(file_path).iterdir():
+        #     print(tile_path)
+        wsi_batch = np.load(file_path)
+        
+
+        if wsi_batch.shape[0] > self.max_bag_size:
+            wsi_batch, name_batch, _ = to_fixed_size_bag(wsi_batch, name_batch, self.max_bag_size)
+        idx = self._add_to_cache((wsi_batch,name_batch), file_path)
+        file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+        self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # remove invalid cache_idx
+            # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'patient':di['patient'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+        """
+        if data_path not in self.data_cache:
+            self.data_cache[data_path] = [data]
+        else:
+            self.data_cache[data_path].append(data)
+        return len(self.data_cache[data_path]) - 1
+
+    # def get_data_infos(self, type):
+    #     """Get data infos belonging to a certain type of data.
+    #     """
+    #     data_info_type = [di for di in self.data_info if di['type'] == type]
+    #     return data_info_type
+
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_labels(self, indices):
+
+        return [self.data_info[i]['label'] for i in indices]
+        # return self.slideLabelDict.values()
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+            dataset. This will make sure that the data is loaded in case it is
+            not part of the data cache.
+            i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+        
+        # get new cache_idx assigned by _load_data
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        patient = self.data_info[i]['patient']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name, patient
+
+
+
+class RandomHueSaturationValue(object):
+
+    def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+        
+        self.hue_shift_limit = hue_shift_limit
+        self.sat_shift_limit = sat_shift_limit
+        self.val_shift_limit = val_shift_limit
+        self.p = p
+
+    def __call__(self, sample):
+    
+        img = sample #,lbl
+    
+        if np.random.random() < self.p:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+            h, s, v = cv2.split(img)
+            hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+            hue_shift = np.uint8(hue_shift)
+            h += hue_shift
+            sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+            s = cv2.add(s, sat_shift)
+            val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+            v = cv2.add(v, val_shift)
+            img = cv2.merge((h, s, v))
+            img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+        return img #, lbl
+
+def to_fixed_size_bag(bag, names, bag_size: int = 512):
+
+    # duplicate bag instances until bag_size is reached
+
+    # get up to bag_size elements
+    rng = np.random.default_rng()
+    # if bag.shape[0] > bag_size:
+    rng.shuffle(bag)
+    bag_samples = bag[:bag_size, :, :, :]
+        # bag_samples = bag[bag_idxs]
+    # else: 
+    #     original_size = bag.shape[0]
+    #     q, r = divmod(bag_size - bag.shape[0], bag.shape[0])
+    #     # print(bag_size)
+    #     # print(bag.shape[0])
+    #     # print(q, r)
+
+    #     if q > 0:
+    #         bag = np.concatenate([bag for _ in range(q)])
+    #     if r > 0: 
+    #         temp = bag 
+    #         rng.shuffle(temp)
+    #         temp = temp[:r, :, :, :]
+    #         bag = np.concatenate((bag, temp))
+    #     bag_samples = bag
+    name_batch = [''] * bag_size
+    # name_samples = [names[i] for i in bag_idxs]
+    # bag_sample_names = [bag_names[i] for i in bag_idxs]
+    # q, r  = divmod(bag_size, bag_samples.shape[0])
+    # if q > 0:
+    #     bag_samples = torch.cat([bag_samples]*q, 0)
+    
+    # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+    # zero-pad if we don't have enough samples
+    # zero_padded = torch.cat((bag_samples,
+    #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+    return bag_samples, name_batch, min(bag_size, len(bag))
+
+
+class RandomHueSaturationValue(object):
+
+    def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+        
+        self.hue_shift_limit = hue_shift_limit
+        self.sat_shift_limit = sat_shift_limit
+        self.val_shift_limit = val_shift_limit
+        self.p = p
+
+    def __call__(self, sample):
+    
+        img = sample #,lbl
+    
+        if np.random.random() < self.p:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+            h, s, v = cv2.split(img)
+            hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+            hue_shift = np.uint8(hue_shift)
+            h += hue_shift
+            sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+            s = cv2.add(s, sat_shift)
+            val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+            v = cv2.add(v, val_shift)
+            img = cv2.merge((h, s, v))
+            img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+        return img #, lbl
+
+
+
+if __name__ == '__main__':
+    from pathlib import Path
+    import os
+    import time
+    from fast_tensor_dl import FastTensorDataLoader
+    
+    
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    output_dir = f'/{data_root}/debug/augments'
+    os.makedirs(output_dir, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = NPYMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    a = int(len(dataset)* 0.8)
+    b = int(len(dataset) - a)
+    train_data, valid_data = random_split(dataset, [a, b])
+    # print(dataset.dataset)
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
+    dl = DataLoader(train_data, batch_size=1, num_workers=16, sampler=ImbalancedDatasetSampler(train_data), pin_memory=True)
+    # print(len(dl))
+    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    scaler = torch.cuda.amp.GradScaler()
+
+    model_ft = resnet50_baseline(pretrained=True)
+    for param in model_ft.parameters():
+        param.requires_grad = False
+    model_ft.to(device)
+    
+
+    
+    # data = DataLoader(dataset, batch_size=1)
+
+    # print(len(dataset))
+    # # x = 0
+    #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
+    c = 0
+    label_count = [0] *n_classes
+    # print(len(dl))
+    start = time.time()
+    for item in tqdm(dl): 
+
+        
+        # if c >= 10:
+        #     break
+        bag, label, (name, _, patient) = item
+        bag = bag.squeeze(0).float().to(device)
+        label = label.to(device)
+        with torch.cuda.amp.autocast():
+            output = model_ft(bag)
+        c += 1
+    end = time.time()
+
+    print('Bag Time: ', end-start)
+        # print(label)
+        # print(name)
+        # print(patient)
+    #     label_count[torch.argmax(label)] += 1
+        # print(name)
+        # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
+        
+            # print(bag)
+            # print(label)
+        
+    #     # # print(bag.shape)
+    #     # if bag.shape[1] == 1:
+    #     #     print(name)
+    #     #     print(bag.shape)
+        # print(bag.shape)
+        
+    #     # out_dir = Path(output_path) / name
+    #     # os.makedirs(out_dir, exist_ok=True)
+
+    #     # # print(item[2])
+    #     # # print(len(item))
+    #     # # print(item[1])
+    #     # # print(data.shape)
+    #     # # data = data.squeeze()
+    #     # bag = item[0]
+    #     bag = bag.squeeze()
+    #     original = original.squeeze()
+        # output_path = Path(output_dir) / name
+        # output_path.mkdir(exist_ok=True)
+        # for i in range(bag.shape[0]):
+        #     img = bag[i, :, :, :]
+        #     img = img.squeeze()
+            
+        #     img = ((img-img.min())/(img.max() - img.min())) * 255
+        #     # print(img)
+        #     # print(img)
+        #     img = img.numpy().astype(np.uint8).transpose(1,2,0)
+
+            
+        #     img = Image.fromarray(img)
+        #     img = img.convert('RGB')
+        #     img.save(f'{output_path}/{i}.png')
+
+        # c += 1
+            
+    #         o_img = original[i,:,:,:]
+    #         o_img = o_img.squeeze()
+    #         print(o_img.shape)
+    #         o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255
+    #         o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0)
+    #         o_img = Image.fromarray(o_img)
+    #         o_img = o_img.convert('RGB')
+    #         o_img.save(f'{output_path}/{i}_original.png')
+        
+    #     break
+        # else: break
+        # print(data.shape)
+        # print(label)
+    # a = [torch.Tensor((3,256,256))]*3
+    # b = torch.stack(a)
+    # print(b)
+    # c = to_fixed_size_bag(b, 512)
+    # print(c)
\ No newline at end of file
diff --git a/code/datasets/custom_resnet50.py b/code/datasets/custom_resnet50.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c3f33518bc53a1611b94767bc50fbbb169154d0
--- /dev/null
+++ b/code/datasets/custom_resnet50.py
@@ -0,0 +1,122 @@
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch
+import torch.nn.functional as F
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+class Bottleneck_Baseline(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck_Baseline, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+class ResNet_Baseline(nn.Module):
+
+    def __init__(self, block, layers):
+        self.inplanes = 64
+        super(ResNet_Baseline, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d(1) 
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+
+        return x
+
+def resnet50_baseline(pretrained=False):
+    """Constructs a Modified ResNet-50 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3])
+    if pretrained:
+        model = load_pretrained_weights(model, 'resnet50')
+    return model
+
+def load_pretrained_weights(model, name):
+    pretrained_dict = model_zoo.load_url(model_urls[name])
+    model.load_state_dict(pretrained_dict, strict=False)
+    return model
diff --git a/code/datasets/custom_zarr_dataloader.py b/code/datasets/custom_zarr_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bdf0906adaea225fb630038412ac709b1c125ec
--- /dev/null
+++ b/code/datasets/custom_zarr_dataloader.py
@@ -0,0 +1,528 @@
+'''
+ToDo: remove bag_size
+'''
+
+
+from custom_resnet50 import resnet50_baseline
+import numpy as np
+from pathlib import Path
+import torch
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from tqdm import tqdm
+import torchvision.transforms as transforms
+from PIL import Image
+import cv2
+import json
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from imgaug import augmenters as iaa
+import imgaug as ia
+from torchsampler import ImbalancedDatasetSampler
+import zarr
+
+
+
+class RangeNormalization(object):
+    def __call__(self, sample):
+        img = sample
+        return (img / 255.0 - 0.5) / 0.5
+
+class ZarrMILDataloader(data.Dataset):
+    
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=500, max_bag_size=1000):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.min_bag_size = 120
+        self.empty_slides = []
+        self.corrupt_slides = []
+        # self.label_file = label_path
+        recursive = True
+        exclude_cohorts = ["debug", "test"]
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            print(len(temp_slide_label_dict))
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem 
+                # x_complete_path = Path(self.file_path)/Path(x)
+                for cohort in Path(self.file_path).iterdir():
+                    if cohort.stem not in exclude_cohorts:
+                        x_complete_path = cohort / 'BLOCKS'
+                        for slide_dir in Path(x_complete_path).iterdir():
+                            if slide_dir.suffix == '.zarr':
+                                
+                                if slide_dir.stem == x:
+                            # if len(list(x_complete_path.iterdir())) > self.min_bag_size:
+                            # print(x_complete_path)
+                                    self.slideLabelDict[x] = y
+                                    self.files.append(slide_dir)
+                            # else: self.empty_slides.append(x_complete_path)
+                    
+        print(len(self.files))
+
+        print(f'Slides with bag size under {self.min_bag_size}: ', len(self.empty_slides))
+        # print(self.empty_slides)
+        # print(len(self.files))
+        # print(len(self.corrupt_slides))
+        # print(self.corrupt_slides)
+        home = Path.cwd().parts[1]
+        slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
+        with open(slide_patient_dict_path, 'r') as f:
+            slide_patient_dict = json.load(f)
+
+        for slide_dir in tqdm(self.files):
+            self._add_data_infos(str(slide_dir.resolve()), load_data, slide_patient_dict)
+
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
+        ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            # iaa.OneOf([
+            #     sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+            #     sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+            #     sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            # ], name="MyOneOf")
+
+        ], name="MyAug")
+        self.albu_transforms = A.Compose([
+            A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30, always_apply=False, p=0.5),
+            A.ColorJitter(always_apply=False, p=0.5),
+            A.RandomGamma(gamma_limit=(80,120)),
+            A.Flip(p=0.5),
+            A.RandomRotate90(p=0.5),
+            # A.OneOf([
+            #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+            #     A.Affine(
+            #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+            #         rotate=(-45, 45),
+            #         shear=(-4, 4),
+            #         cval=8,
+            #         )
+            # ]),
+            A.Normalize(),
+            ToTensorV2(),
+        ])
+        # self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+        #     # A.RandomGamma(),
+        #     # A.HorizontalFlip(),
+        #     # A.VerticalFlip(),
+        #     # A.RandomRotate90(),
+        #     # A.OneOf([
+        #     #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #     #     A.Affine(
+        #     #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+        #     #         rotate=(-45, 45),
+        #     #         shear=(-4, 4),
+        #     #         cval=8,
+        #     #         )
+        #     # ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+
+        ])
+        self.img_transforms = transforms.Compose([    
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandomVerticalFlip(p=1),
+            # histoTransforms.AutoRandomRotation(),
+            transforms.Lambda(lambda a: np.array(a)),
+        ]) 
+        self.hsv_transforms = transforms.Compose([
+            RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
+            transforms.ToTensor()
+        ])
+
+    def __getitem__(self, index):
+        # get data
+        (batch, batch_names), label, name, patient = self.get_data(index)
+        out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
+        
+        if self.mode == 'train':
+            # print(img)
+            # print(.shape)
+            for img in batch: # expects numpy 
+                # img = img.numpy().astype(np.uint8)
+                # img = self.albu_transforms(image=img)
+                # print(img)
+                # print(img.shape)
+                # print(img)
+                img = img.astype(np.uint8)
+
+                img = seq_img_d.augment_image(img)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+                # img = self.albu_transforms(image=img)
+                # out_batch.append(img['image'])
+
+        else:
+            for img in batch:
+                img = img.numpy().astype(np.uint8)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+
+        # if len(out_batch) == 0:
+        #     # print(name)
+        #     out_batch = torch.randn(self.bag_size,3,256,256)
+        # else: 
+        # print(len(out_batch))
+        out_batch = torch.stack(out_batch)
+        # print(out_batch.shape)
+        # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+        # print(out_batch.shape)
+        # if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]):
+        #     print(name)
+        #     print(out_batch.shape)
+        # out_batch = torch.permute(out_batch, (0, 2,1,3))
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        # print(out_batch)
+        return out_batch, label, (name, batch_names, patient) #, name_batch
+
+    def __len__(self):
+        return len(self.data_info)
+    
+    def _add_data_infos(self, file_path, load_data, slide_patient_dict):
+
+        
+        wsi_name = Path(file_path).stem
+        if wsi_name in self.slideLabelDict:
+            # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
+            label = self.slideLabelDict[wsi_name]
+            patient = slide_patient_dict[wsi_name]
+            idx = -1
+            self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'patient': patient,'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        wsi_batch = []
+        name_batch = []
+        # print(wsi_batch)
+        # for tile_path in Path(file_path).iterdir():
+        #     print(tile_path)
+        # wsi_batch = np.load(file_path)
+        wsi_batch = zarr.open(file_path, 'r')[:]
+        
+
+        if wsi_batch.shape[0] > self.max_bag_size:
+            wsi_batch, name_batch, _ = to_fixed_size_bag(wsi_batch, name_batch, self.max_bag_size)
+        idx = self._add_to_cache((wsi_batch,name_batch), file_path)
+        file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+        self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # remove invalid cache_idx
+            # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'patient':di['patient'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+        """
+        if data_path not in self.data_cache:
+            self.data_cache[data_path] = [data]
+        else:
+            self.data_cache[data_path].append(data)
+        return len(self.data_cache[data_path]) - 1
+
+    # def get_data_infos(self, type):
+    #     """Get data infos belonging to a certain type of data.
+    #     """
+    #     data_info_type = [di for di in self.data_info if di['type'] == type]
+    #     return data_info_type
+
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_labels(self, indices):
+
+        return [self.data_info[i]['label'] for i in indices]
+        # return self.slideLabelDict.values()
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+            dataset. This will make sure that the data is loaded in case it is
+            not part of the data cache.
+            i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+        
+        # get new cache_idx assigned by _load_data
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        patient = self.data_info[i]['patient']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name, patient
+
+
+
+class RandomHueSaturationValue(object):
+
+    def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+        
+        self.hue_shift_limit = hue_shift_limit
+        self.sat_shift_limit = sat_shift_limit
+        self.val_shift_limit = val_shift_limit
+        self.p = p
+
+    def __call__(self, sample):
+    
+        img = sample #,lbl
+    
+        if np.random.random() < self.p:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+            h, s, v = cv2.split(img)
+            hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+            hue_shift = np.uint8(hue_shift)
+            h += hue_shift
+            sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+            s = cv2.add(s, sat_shift)
+            val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+            v = cv2.add(v, val_shift)
+            img = cv2.merge((h, s, v))
+            img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+        return img #, lbl
+
+def to_fixed_size_bag(bag, names, bag_size: int = 512):
+
+    # duplicate bag instances until bag_size is reached
+
+    # get up to bag_size elements
+    rng = np.random.default_rng()
+    # if bag.shape[0] > bag_size:
+    rng.shuffle(bag)
+    bag_samples = bag[:bag_size, :, :, :]
+        # bag_samples = bag[bag_idxs]
+    # else: 
+    #     original_size = bag.shape[0]
+    #     q, r = divmod(bag_size - bag.shape[0], bag.shape[0])
+    #     # print(bag_size)
+    #     # print(bag.shape[0])
+    #     # print(q, r)
+
+    #     if q > 0:
+    #         bag = np.concatenate([bag for _ in range(q)])
+    #     if r > 0: 
+    #         temp = bag 
+    #         rng.shuffle(temp)
+    #         temp = temp[:r, :, :, :]
+    #         bag = np.concatenate((bag, temp))
+    #     bag_samples = bag
+    name_batch = [''] * bag_size
+    # name_samples = [names[i] for i in bag_idxs]
+    # bag_sample_names = [bag_names[i] for i in bag_idxs]
+    # q, r  = divmod(bag_size, bag_samples.shape[0])
+    # if q > 0:
+    #     bag_samples = torch.cat([bag_samples]*q, 0)
+    
+    # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+    # zero-pad if we don't have enough samples
+    # zero_padded = torch.cat((bag_samples,
+    #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+    return bag_samples, name_batch, min(bag_size, len(bag))
+
+
+class RandomHueSaturationValue(object):
+
+    def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+        
+        self.hue_shift_limit = hue_shift_limit
+        self.sat_shift_limit = sat_shift_limit
+        self.val_shift_limit = val_shift_limit
+        self.p = p
+
+    def __call__(self, sample):
+    
+        img = sample #,lbl
+    
+        if np.random.random() < self.p:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+            h, s, v = cv2.split(img)
+            hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+            hue_shift = np.uint8(hue_shift)
+            h += hue_shift
+            sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+            s = cv2.add(s, sat_shift)
+            val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+            v = cv2.add(v, val_shift)
+            img = cv2.merge((h, s, v))
+            img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+        return img #, lbl
+
+
+
+if __name__ == '__main__':
+    from pathlib import Path
+    import os
+    import time
+    from fast_tensor_dl import FastTensorDataLoader
+    
+    
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    output_dir = f'/{data_root}/debug/augments'
+    os.makedirs(output_dir, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = ZarrMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    a = int(len(dataset)* 0.8)
+    b = int(len(dataset) - a)
+    train_data, valid_data = random_split(dataset, [a, b])
+    # print(dataset.dataset)
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
+    dl = DataLoader(train_data, batch_size=1, num_workers=16, sampler=ImbalancedDatasetSampler(train_data), pin_memory=True)
+    # print(len(dl))
+    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    scaler = torch.cuda.amp.GradScaler()
+
+    model_ft = resnet50_baseline(pretrained=True)
+    for param in model_ft.parameters():
+        param.requires_grad = False
+    model_ft.to(device)
+    
+
+    
+    # data = DataLoader(dataset, batch_size=1)
+
+    # print(len(dataset))
+    # # x = 0
+    #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
+    c = 0
+    label_count = [0] *n_classes
+    # print(len(dl))
+    start = time.time()
+    for item in tqdm(dl): 
+
+        
+        # if c >= 10:
+        #     break
+        bag, label, (name, _, patient) = item
+        bag = bag.squeeze(0).float().to(device)
+        label = label.to(device)
+        with torch.cuda.amp.autocast():
+            output = model_ft(bag)
+        c += 1
+    end = time.time()
+
+    print('Bag Time: ', end-start)
+        # print(label)
+        # print(name)
+        # print(patient)
+    #     label_count[torch.argmax(label)] += 1
+        # print(name)
+        # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
+        
+            # print(bag)
+            # print(label)
+        
+    #     # # print(bag.shape)
+    #     # if bag.shape[1] == 1:
+    #     #     print(name)
+    #     #     print(bag.shape)
+        # print(bag.shape)
+        
+    #     # out_dir = Path(output_path) / name
+    #     # os.makedirs(out_dir, exist_ok=True)
+
+    #     # # print(item[2])
+    #     # # print(len(item))
+    #     # # print(item[1])
+    #     # # print(data.shape)
+    #     # # data = data.squeeze()
+    #     # bag = item[0]
+    #     bag = bag.squeeze()
+    #     original = original.squeeze()
+        # output_path = Path(output_dir) / name
+        # output_path.mkdir(exist_ok=True)
+        # for i in range(bag.shape[0]):
+        #     img = bag[i, :, :, :]
+        #     img = img.squeeze()
+            
+        #     img = ((img-img.min())/(img.max() - img.min())) * 255
+        #     # print(img)
+        #     # print(img)
+        #     img = img.numpy().astype(np.uint8).transpose(1,2,0)
+
+            
+        #     img = Image.fromarray(img)
+        #     img = img.convert('RGB')
+        #     img.save(f'{output_path}/{i}.png')
+
+        # c += 1
+            
+    #         o_img = original[i,:,:,:]
+    #         o_img = o_img.squeeze()
+    #         print(o_img.shape)
+    #         o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255
+    #         o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0)
+    #         o_img = Image.fromarray(o_img)
+    #         o_img = o_img.convert('RGB')
+    #         o_img.save(f'{output_path}/{i}_original.png')
+        
+    #     break
+        # else: break
+        # print(data.shape)
+        # print(label)
+    # a = [torch.Tensor((3,256,256))]*3
+    # b = torch.stack(a)
+    # print(b)
+    # c = to_fixed_size_bag(b, 512)
+    # print(c)
\ No newline at end of file
diff --git a/code/datasets/dali_dataloader.py b/code/datasets/dali_dataloader.py
index 0851613c97ebc3705d324e03cb22ce6060e12315..a5efe96fc66f9940ed6ee387dc50d24d01a63cdd 100644
--- a/code/datasets/dali_dataloader.py
+++ b/code/datasets/dali_dataloader.py
@@ -1,139 +1,236 @@
+import nvidia.dali as dali
 from nvidia.dali import pipeline_def
 from nvidia.dali.pipeline import Pipeline
 import nvidia.dali.fn as fn
+# import nvidia.dali.fn.readers.file as file
+from nvidia.dali.fn.decoders import image, image_random_crop
+
 import nvidia.dali.types as types
+from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
 
 from pathlib import Path
-import matplotlib.pyplot as plt
-import matplotlib.gridspec as gridspec
-import math
 import json
+from tqdm import tqdm
+from random import shuffle
 import numpy as np
-import cupy as cp
+from PIL import Image
 import torch
-import imageio
+import os
+
+sequence_length = 800
 
-batch_size = 10
-home = Path.cwd().parts[1]
-# image_filename = f"/{home}/ylan/data/DeepGraft/224_128um/Aachen_Biopsy_Slides/BLOCKS/"
+# Path to MNIST dataset
+# data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
+# file_path = f'/home/ylan/data/DeepGraft/224_128um_v2'
 
 class ExternalInputIterator(object):
-    def __init__(self, batch_size):
-        self.file_path = f"/{home}/ylan/data/DeepGraft/224_128um/"
-        # self.label_file = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
-        self.label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_tcmr_viral.json'
+    def __init__(self, file_path, label_path, mode, n_classes, device_id, num_gpus, batch_size = 1, max_bag_size=sequence_length):
+        # Collect (slide_dir, label, patient) triples for one GPU shard.
+        #
+        # file_path:    root dir laid out as <cohort>/BLOCKS/<slide>/<tile files>
+        # label_path:   JSON whose [mode] entry is a list of [slide, label] pairs
+        # mode:         split key inside the label JSON (e.g. 'train'/'test')
+        # n_classes:    stored on self but not used inside this method
+        # device_id / num_gpus: used below to slice self.files into this
+        #               worker's contiguous shard
+        # max_bag_size: upper bound on tiles per bag (module-level
+        #               sequence_length by default)
+
+        self.file_path = file_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.mode = mode
+        self.max_bag_size = max_bag_size
+        self.min_bag_size = 120
         self.batch_size = batch_size
-        
-        mode = 'test'
-        # with open(self.images_dir + "file_list.txt", 'r') as f:
-        #     self.files = [line.rstrip() for line in f if line != '']
-        self.slideLabelDict = {}
+
+        self.data_info = []
+        self.data_cache = {}
         self.files = []
+        self.slideLabelDict = {}
         self.empty_slides = []
+
+        # Slide name -> patient id mapping (hard-coded training-table path).
+        home = Path.cwd().parts[1]
+        slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
+        with open(slide_patient_dict_path, 'r') as f:
+            self.slidePatientDict = json.load(f)
+
         with open(self.label_path, 'r') as f:
             temp_slide_label_dict = json.load(f)[mode]
+            print(len(temp_slide_label_dict))
             for (x, y) in temp_slide_label_dict:
                 x = Path(x).stem 
-
                 # x_complete_path = Path(self.file_path)/Path(x)
+                # A slide may live in any cohort; probe every cohort's BLOCKS dir.
                 for cohort in Path(self.file_path).iterdir():
                     x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
                     if x_complete_path.is_dir():
-                        if len(list(x_complete_path.iterdir())) > 50:
+                        # Keep only slides with more than min_bag_size tiles;
+                        # the rest are tracked in empty_slides for debugging.
+                        if len(list(x_complete_path.iterdir())) > self.min_bag_size:
                         # print(x_complete_path)
-                            # self.slideLabelDict[x] = y
-                            self.files.append((x_complete_path, y))
+                            self.slideLabelDict[x] = y
+                            patient = self.slidePatientDict[x_complete_path.stem]
+                            self.files.append((x_complete_path, y, patient))
                         else: self.empty_slides.append(x_complete_path)
         
-        # shuffle(self.files)
+
+        # for slide_dir in tqdm(self.files):
+        #     # self._add_data_infos(str(slide_dir.resolve()), load_data, slide_patient_dict)
+        #     wsi_name = Path(slide_dir).stem
+        #     if wsi_name in self.slideLabelDict:
+        #         label = self.slideLabelDict[wsi_name]
+        #         patient = self.slidePatientDict[wsi_name]
+        #         idx = -1
+        #         self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'patient': patient,'cache_idx': idx})
+
+        self.dataset_len = len(self.files)
+
+        # Shard the slide list evenly across GPUs by contiguous slice.
+        self.files = self.files[self.dataset_len * device_id//num_gpus:self.dataset_len*(device_id+1)//num_gpus]
+
+        self.n = len(self.files)
+
+        # NOTE(review): leftover debug code — requires the DALI_EXTRA_PATH env
+        # var to be set, and jpeg_file is never used afterwards; consider removing.
+        test_data_root = os.environ['DALI_EXTRA_PATH']
+        jpeg_file = os.path.join(test_data_root, 'db', 'single', 'jpeg', '510', 'ship-1083562_640.jpg')
 
     def __iter__(self):
+        # Reset the cursor and reshuffle slide order at the start of each epoch.
         self.i = 0
-        self.n = len(self.files)
+        shuffle(self.files)
         return self
-
+    
     def __next__(self):
+        # Produce ONE sample for DALI external_source (batch=False): a list of
+        # max_bag_size raw jpeg byte buffers followed by one uint8 label tensor.
+        # NOTE(review): `labels`, `patient` and `wsi_batch` are created or
+        # unpacked but never used below.
-        batch = []
+        batch = [] 
         labels = []
-        file_names = []
-
-        for _ in range(self.batch_size):
-            wsi_batch = []
-            wsi_filename, label = self.files[self.i]
-            # jpeg_filename, label = self.files[self.i].split(' ')
-            name = Path(wsi_filename).stem
-            file_names.append(name)
-            for img_path in Path(wsi_filename).iterdir():
-                # f = open(img, 'rb')
-
-                f = imageio.imread(img_path)
-                img = cp.asarray(f)
-                wsi_batch.append(img.astype(cp.uint8))
-            wsi_batch = cp.stack(wsi_batch)
-            batch.append(wsi_batch)
-            labels.append(cp.array([label], dtype = cp.uint8))
+
+        # End of this shard: reshuffle for the next epoch, then stop iteration.
+        if self.i >=self.n:
+            self.__iter__()
+            raise StopIteration
+        
+        # for _ in range(self.batch_size):
+        wsi_path, label, patient = self.files[self.i]
+        wsi_batch = []
+        # Read raw, still-encoded jpeg bytes; decoding happens later on the GPU
+        # in the pipeline (fn.decoders.image, device="mixed").
+        for tile_path in Path(wsi_path).iterdir():
+            np_img = np.fromfile(tile_path, dtype=np.uint8)
+
+            batch.append(np_img)
+
+        # test_data_root = os.environ['DALI_EXTRA_PATH']
+        # jpeg_file = os.path.join(test_data_root, 'db', 'single', 'jpeg', '510', 'ship-1083562_640.jpg')
+        # wsi_batch = [np.fromfile(jpeg_file, dtype=np.uint8) for _ in range(sequence_length)]
+
+            # np_img = np.asarray(Image.open(tile_path)).astype(np.uint8)
+            # print(np_img.shape)
+            # wsi_batch.append(np_img)
             
-            self.i = (self.i + 1) % self.n
+            # print(np_img)
+
+        
+        
+        # wsi_batch = np.stack(wsi_batch, axis=0) 
+        # # print(wsi_batch.shape)
+        # print(wsi_batch)
+        # print(len(wsi_batch))
+        # if len(wsi_batch) > self.max_bag_size:
+        # Pad or subsample the buffer list to exactly max_bag_size entries.
+        wsi_batch, _ = self.to_fixed_size_bag(batch, self.max_bag_size)
+        # batch.append(wsi_batch)
+        batch = wsi_batch
+        batch.append(torch.tensor([int(label)], dtype=torch.uint8))    
+        self.i += 1
+        # for i in range(len(batch)):
+        #     print(batch[i].shape)
+        #     print(labels[i])
         # print(batch)
-        # print(labels)
-        return (batch, labels)
-
-
-eii = ExternalInputIterator(batch_size=10)
-
-@pipeline_def()
-def hsv_pipeline(device, hue, saturation, value):
-    # files, labels = fn.readers.file(file_root=image_filename)
-    files, labels = fn.external_source(source=eii, num_outputs=2, dtype=types.UINT8)
-    images = fn.decoders.image(files, device = 'cpu' if device == 'cpu' else 'mixed')
-    converted = fn.hsv(images, hue=hue, saturation=saturation, value=value)
-    return images, converted
-
-def display(outputs, idx, columns=2, captions=None, cpu=True):
-    rows = int(math.ceil(len(outputs) / columns))
-    fig = plt.figure()
-    fig.set_size_inches(16, 6 * rows)
-    gs = gridspec.GridSpec(rows, columns)
-    row = 0
-    col = 0
-    for i, out in enumerate(outputs):
-        plt.subplot(gs[i])
-        plt.axis("off")
-        if captions is not None:
-            plt.title(captions[i])
-        plt.imshow(out.at(idx) if cpu else out.as_cpu().at(idx))
-
-# pipe_cpu = hsv_pipeline(device='cpu', hue=120, saturation=1, value=0.4, batch_size=batch_size, num_threads=1, device_id=0)
-# pipe_cpu.build()
-# cpu_output = pipe_cpu.run()
-
-# display(cpu_output, 3, captions=["Original", "Hue=120, Saturation=1, Value=0.4"])
-
-# pipe_gpu = hsv_pipeline(device='gpu', hue=120, saturation=2, value=1, batch_size=batch_size, num_threads=1, device_id=0)
-# pipe_gpu.build()
-# gpu_output = pipe_gpu.run()
-
-# display(gpu_output, 0, cpu=False, captions=["Original", "Hue=120, Saturation=2, Value=1"])
-
-
-pipe_gpu = Pipeline(batch_size=batch_size, num_threads=2, device_id=0)
-with pipe_gpu:
-    images, labels = fn.external_source(source=eii, num_outputs=2, device="gpu", dtype=types.UINT8)
-    enhance = fn.brightness_contrast(images, contrast=2)
-    pipe_gpu.set_outputs(enhance, labels)
-
-pipe_gpu.build()
-pipe_out_gpu = pipe_gpu.run()
-batch_gpu = pipe_out_gpu[0].as_cpu()
-labels_gpu = pipe_out_gpu[1].as_cpu()
-# file_names = pipe_out_gpu[2].as_cpu()
-
-output_path = f"/{home}/ylan/data/DeepGraft/224_128um/debug/augments/"
-output_path.mkdir(exist_ok=True)
-
-
-img = batch_gpu.at(2)
-print(img.shape)
-print(labels_gpu.at(2))
-plt.axis('off')
-plt.imsave(f'{output_path}/0.jpg', img[0, :, :, :])
\ No newline at end of file
+        return batch
+        # return (batch, labels)       
+    
+    def __len__(self):
+        # NOTE(review): returns the FULL (pre-shard) dataset length, not the
+        # length of this GPU's shard (self.n) — confirm this is intentional.
+        return self.dataset_len
+
+    def to_fixed_size_bag(self, bag, bag_size):
+        # Force `bag` (a list of 1-D uint8 jpeg byte buffers) to exactly
+        # bag_size items. Returns (bag_samples, effective_size) where
+        # effective_size = number of real (non-padded) tiles.
+        
+        current_size = len(bag)
+        # print(bag)
+        
+        if current_size < bag_size:
+            # NOTE(review): np.empty is UNINITIALIZED memory, not zeros — the
+            # name "zero_padded" is misleading and the downstream jpeg decoder
+            # may fail on these garbage buffers; confirm the padding entries
+            # are never actually decoded.
+            zero_padded = [np.empty(1000, dtype=np.uint8)] * (bag_size - current_size)
+            bag_samples = bag + zero_padded
+            # while current_size < bag_size: 
+                # bag.append(np.empty(1, dtype=np.uint8)) 
+
+            # zero_padded = np.empty(5000) 
+        
+        else:
+
+            # Random subsample when there are more tiles than bag_size.
+            # NOTE(review): np.random.permutation over a ragged list of arrays
+            # goes through an object array — verify each buffer stays intact.
+            bag_samples = list(np.random.permutation(bag)[:bag_size])
+
+        # bag_samples = list(np.array(bag, dtype=object)[bag_idxs])
+
+        print(len(bag_samples))
+
+        return bag_samples, min(bag_size, len(bag))
+    next = __next__
+
+def ExternalSourcePipeline(batch_size, num_threads, device_id, external_data):
+    # Build a DALI pipeline fed by `external_data` (an ExternalInputIterator).
+    # Each sample yielded by the iterator is sequence_length jpeg buffers plus
+    # one label tensor, hence num_outputs=sequence_length+1 with batch=False.
+    pipe = Pipeline(batch_size, num_threads, device_id)
+    with pipe:
+        *jpegs, label = fn.external_source(source=external_data, num_outputs=sequence_length+1, dtype=types.UINT8, batch=False)
+
+        # "mixed" = CPU parses the jpeg header, GPU does the actual decode.
+        images = fn.decoders.image(jpegs, device="mixed")
+        images = fn.resize(images, resize_x=224, resize_y=224)
+        
+        # Stack the per-tile images into one bag tensor per sample.
+        bag = fn.stack(*images)
+        # bag = fn.reshape(bag, layout='')
+        # output = fn.cast(bag, dtype=types.UINT8)
+
+        # output.append(images)
+        # print(output)
+        pipe.set_outputs(bag, label)
+    return pipe
+# @pipeline_def
+# def get_pipeline(file_root:str, size:int=224, validation_size: Optional[int]=256, random_shuffle: bool=False, training:bool=True, decoder_device:str= 'mixed', device:str='gpu'):
+
+#     images, labels = file(file_root=file_root, random_shuffle=random_shuffle, name='Reader')
+
+#     if training: 
+#         images = image_random_crop(images,
+#                                     random_area=[0.08, 1.0],
+#                                     random_aspect_ratio = [0.75, 1.3],
+#                                     device=decoder_device,
+#         )
+#         images = fn.resize(images, 
+#                         size=size, 
+#                         device=device)
+#         mirror = fn.random.coin_flip(
+# 			# probability refers to the probability of getting a 1
+# 			probability=0.5,
+# 		)
+#     else: 
+#         images = image(images, device=decoder_device)
+#         images = resize(images, size=validation_size, mode='not_smaller', device=device)
+#         mirror = False
+
+#     images = fn.crop_mirror_normalize(images, 
+#                                         crop=(size,size),
+#                                         mirror=mirror,
+#                                         mean=[0.485 * 255,0.456 * 255,0.406 * 255],
+# 		                                std=[0.229 * 255,0.224 * 255,0.225 * 255],
+# 		                                device=device,
+#         )
+#     if device == 'gpu':
+#         labels = labels.gpu()
+
+#     return images, labels
+
+# training_pipeline = get_pipeline(batch_size=1, num_threads=8, device_id=0, file_root=f'/home/ylan/data/DeepGraft/224_128um_v2', random_shuffle=True, training=True, size=224)
+# validation_pipeline = get_pipeline(batch_size=1, num_threads=8, device_id=0, file_root=f'/home/ylan/data/DeepGraft/224_128um_v2', random_shuffle=True, training=False, size=224)
+
+# training_pipeline.build()
+# validation_pipeline.build()
+
+# training_dataloader = DALIClassificationIterator(pipelines=training_pipeline, reader_name='Reader', last_batch_policy=LastBatchPolicy.PARTIAL, auto_reset=True)
+# validation_pipeline = DALIClassificationIterator(pipelines=validation_pipeline, reader_name='Reader', last_batch_policy=LastBatchPolicy.PARTIAL, auto_reset=True)
+
+if __name__ == '__main__':
+    # Smoke test: build the iterator and pipeline on GPU 0 (single-GPU shard)
+    # and run 3 epochs, printing the realized batch size per iteration.
+
+    home = Path.cwd().parts[1]
+    file_path = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    eii = ExternalInputIterator(file_path, label_path, mode="train", n_classes=2, device_id=0, num_gpus=1)
+
+    pipe = ExternalSourcePipeline(batch_size=1, num_threads=2, device_id = 0,
+                              external_data = eii)
+    # PARTIAL policy: the last batch may be smaller than batch_size.
+    pii = DALIClassificationIterator(pipe, last_batch_padded=True, last_batch_policy=LastBatchPolicy.PARTIAL)
+
+    for e in range(3):
+        for i, data in enumerate(pii):
+            # print(data)
+            print("epoch: {}, iter {}, real batch size: {}".format(e, i, len(data[0]["data"])))
+        pii.reset()
+            
diff --git a/code/datasets/data_interface.py b/code/datasets/data_interface.py
index efa104c4f3a1c0ce12ece5cfc285202c8f8f1825..fc4acabfae48ace4abb9833a2b67cbbecde6f2e8 100644
--- a/code/datasets/data_interface.py
+++ b/code/datasets/data_interface.py
@@ -2,8 +2,8 @@ import inspect # 查看python 类的参数和模块、函数代码
 import importlib # In order to dynamically import the library
 from typing import Optional
 import pytorch_lightning as pl
-from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.fit_loop import FitLoop
+# from pytorch_lightning.loops.base import Loop
+# from pytorch_lightning.loops.fit_loop import FitLoop
 
 from torch.utils.data import random_split, DataLoader
 from torch.utils.data.dataset import Dataset, Subset
@@ -12,8 +12,9 @@ from torchvision import transforms
 from .camel_dataloader import FeatureBagLoader
 from .custom_dataloader import HDF5MILDataloader
 from .custom_jpg_dataloader import JPGMILDataloader
+from .zarr_feature_dataloader import ZarrFeatureBagLoader
 from pathlib import Path
-from transformers import AutoFeatureExtractor
+# from transformers import AutoFeatureExtractor
 from torchsampler import ImbalancedDatasetSampler
 
 from abc import ABC, abstractclassmethod, abstractmethod
@@ -119,13 +120,13 @@ class DataInterface(pl.LightningDataModule):
 
 class MILDataModule(pl.LightningDataModule):
 
-    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=50, n_classes=2, cache: bool=True, *args, **kwargs):
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, use_features=False, *args, **kwargs):
         super().__init__()
         self.data_root = data_root
         self.label_path = label_path
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.image_size = 384
+        self.image_size = 224
         self.n_classes = n_classes
         self.target_number = 9
         self.mean_bag_length = 10
@@ -134,8 +135,15 @@ class MILDataModule(pl.LightningDataModule):
         self.num_bags_test = 50
         self.seed = 1
 
-        self.cache = True
+
+        self.cache = cache
         self.fe_transform = None
+        if not use_features: 
+            self.base_dataloader = JPGMILDataloader
+        else: 
+            
+            self.base_dataloader = ZarrFeatureBagLoader
+            self.cache = True
         
 
 
@@ -143,19 +151,22 @@ class MILDataModule(pl.LightningDataModule):
         home = Path.cwd().parts[1]
 
         if stage in (None, 'fit'):
-            dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
+            dataset = self.base_dataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, cache=self.cache)
+            # dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
+            print(len(dataset))
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
             self.train_data, self.valid_data = random_split(dataset, [a, b])
 
         if stage in (None, 'test'):
-            self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1)
+            self.test_data = self.base_dataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1)
 
         return super().setup(stage=stage)
 
         
 
     def train_dataloader(self) -> DataLoader:
+        # NOTE(review): ImbalancedDatasetSampler (torchsampler) is used to
+        # rebalance classes during training — verify it supports the
+        # random_split Subset produced in setup().
+        # return DataLoader(self.train_data,  batch_size = self.batch_size, num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         return DataLoader(self.train_data,  batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         #sampler=ImbalancedDatasetSampler(self.train_data)
     def val_dataloader(self) -> DataLoader:
diff --git a/code/datasets/fast_tensor_dl.py b/code/datasets/fast_tensor_dl.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4eff2611089bd62d6ccb8a579d857610d04890c
--- /dev/null
+++ b/code/datasets/fast_tensor_dl.py
@@ -0,0 +1,46 @@
+import torch
+
+class FastTensorDataLoader:
+    """
+    A DataLoader-like object for a set of tensors that can be much faster than
+    TensorDataset + DataLoader because dataloader grabs individual indices of
+    the dataset and calls cat (slow).
+    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
+    """
+    def __init__(self, *tensors, batch_size=32, shuffle=False):
+        """
+        Initialize a FastTensorDataLoader.
+        :param *tensors: tensors to store. Must have the same length @ dim 0.
+        :param batch_size: batch size to load.
+        :param shuffle: if True, shuffle the data *in-place* whenever an
+            iterator is created out of this object.
+        :returns: A FastTensorDataLoader.
+        """
+        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
+        self.tensors = tensors
+
+        self.dataset_len = self.tensors[0].shape[0]
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+
+        # Calculate # batches (a non-empty remainder counts as one more batch)
+        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
+        if remainder > 0:
+            n_batches += 1
+        self.n_batches = n_batches
+    def __iter__(self):
+        if self.shuffle:
+            # One shared permutation keeps all tensors aligned row-wise.
+            # NOTE(review): this rebinds self.tensors from a tuple to a list;
+            # harmless for the indexing/iteration done in this class.
+            r = torch.randperm(self.dataset_len)
+            self.tensors = [t[r] for t in self.tensors]
+        self.i = 0
+        return self
+
+    def __next__(self):
+        if self.i >= self.dataset_len:
+            raise StopIteration
+        # Slicing past the end is safe in torch: the final batch is shorter.
+        batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors)
+        self.i += self.batch_size
+        return batch
+
+    def __len__(self):
+        return self.n_batches
\ No newline at end of file
diff --git a/code/datasets/feature_extractor.py b/code/datasets/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0057c8effb518dd68d00f33035081578d9615166
--- /dev/null
+++ b/code/datasets/feature_extractor.py
@@ -0,0 +1,186 @@
+import numpy as np
+from pathlib import Path
+from PIL import Image
+from tqdm import tqdm
+import zarr
+from numcodecs import Blosc
+import torch
+import torch.nn as nn
+import ResNet as ResNet 
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+import re
+from imgaug import augmenters as iaa
+
+def chunker(seq, size):
+    # Yield successive `size`-length slices of `seq` (last chunk may be shorter).
+    return (seq[pos:pos + size] for pos in range(0, len(seq), size))
+
+def get_coords(batch_names): #TODO: Change function for precise coords
+    # Parse "(x_y)" tile coordinates from each tile-name stem, using the LAST
+    # parenthesised group in the name. Raises if a name has no "(...)" group.
+    coords = []
+    
+    for tile_name in batch_names: 
+        # print(tile_name)
+        pos = re.findall(r'\((.*?)\)', tile_name)
+        x, y = pos[-1].split('_')
+        coords.append((int(x),int(y)))
+    return coords
+
+def augment(img):
+    # Apply one random colour/geometry augmentation pass (imgaug) to a single
+    # uint8 HWC image and return the augmented image.
+
+    # Probability wrappers: each applies its augmenter with the given chance.
+    sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+    sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+    sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+    sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+    sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+    # NOTE(review): this local name shadows the module-level
+    # `torchvision.transforms as transforms` import — harmless inside this
+    # function, but easy to misread.
+    transforms = iaa.Sequential([
+        iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
+        sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+        iaa.Fliplr(0.5, name="MyFlipLR"),
+        iaa.Flipud(0.5, name="MyFlipUD"),
+        sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+        # Exactly one of the three spatial deformations is picked per call.
+        iaa.OneOf([
+            sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+            sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+            sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+        ], name="MyOneOf")
+    ])
+    # to_deterministic() freezes the sampled parameters for this call.
+    seq_img_d = transforms.to_deterministic()
+    img = seq_img_d.augment_image(img)
+
+    return img
+
+
+if __name__ == '__main__':
+    # Extract RetCCL-ResNet50 features for every slide in the selected cohorts.
+    # Each slide is processed 5 times with random augmentations; each pass is
+    # saved as <slide>_aug<n>.zarr containing the feature array and tile coords.
+
+    home = Path.cwd().parts[1]
+    
+    data_root = Path(f'/{home}/ylan/data/DeepGraft/224_128um_v2')
+    # output_path = Path(f'/{home}/ylan/wsi_tools/debug/zarr')
+    cohorts = ['DEEPGRAFT_RA', 'Leuven'] #, 
+    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU'] #, 
+    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, 
+    compressor = Blosc(cname='blosclz', clevel=3)
+
+    # ImageNet-style normalization applied AFTER augmentation.
+    val_transforms = transforms.Compose([
+            # 
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225],
+            ),
+            # RangeNormalization(),
+        ])
+
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # NOTE(review): scaler is never used (the autocast call below is commented
+    # out); consider removing.
+    scaler = torch.cuda.amp.GradScaler()
+    n_classes = 2
+    out_features = 1024
+    # Load RetCCL weights, freeze everything (incl. BatchNorm stats), then
+    # replace the classifier head with a 2048->1024 projection for features.
+    model_ft = ResNet.resnet50(num_classes=n_classes, mlp=False, two_branch=False, normlinear=True)
+    home = Path.cwd().parts[1]
+    model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+    for param in model_ft.parameters():
+        param.requires_grad = False
+    for m in model_ft.modules():
+        if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d):
+            m.eval()
+            m.weight.requires_grad = False
+            m.bias.requires_grad = False
+    model_ft.fc = nn.Linear(2048, out_features)
+    model_ft.eval()
+    model_ft.to(device)
+
+    batch_size = 100
+
+    for f in data_root.iterdir():
+        
+        if f.stem in cohorts:
+            print(f)
+            fe_path = f / 'FEATURES_RETCCL'
+            fe_path.mkdir(parents=True, exist_ok=True)
+            
+            # num_files = len(list((f / 'BLOCKS').iterdir()))
+            slide_list = []
+            for slide in (f / 'BLOCKS').iterdir():
+                if Path(slide).is_dir(): 
+                    if slide.suffix != '.zarr':
+                        slide_list.append(slide)
+
+            # NOTE(review): pbar total is len(slide_list) but update(1) fires
+            # once per augmentation pass (up to 5x per slide), so the progress
+            # bar overcounts — consider total=len(slide_list)*5.
+            with tqdm(total=len(slide_list)) as pbar:
+                for slide in slide_list:
+                    # print('slide: ', slide)
+
+                    # run every slide 5 times for augments
+                    for n in range(5):
+
+
+                        # Skip passes whose output zarr already exists (resume).
+                        output_path = fe_path / Path(str(slide.stem) + f'_aug{n}.zarr')
+                        if output_path.is_dir():
+                                pbar.update(1)
+                                print(output_path, ' skipped.')
+                                continue
+                        # else:
+                        output_array = []
+                        output_batch_names = []
+                        for tile_path_batch in chunker(list(slide.iterdir()), batch_size):
+                            batch_array = []
+                            batch_names = []
+                            for t in tile_path_batch:
+                                # for n in range(5):
+                                img = np.asarray(Image.open(str(t))).astype(np.uint8) #.astype(np.uint8)
+
+                                img = augment(img)
+
+                                img = val_transforms(img.copy()).to(device)
+                                batch_array.append(img)
+
+                                tile_name = t.stem
+                                batch_names.append(tile_name)
+                            if len(batch_array) == 0:
+                                continue
+                            else:
+                                batch_array = torch.stack(batch_array) 
+                                # with torch.cuda.amp.autocast():
+                                model_output = model_ft(batch_array).detach()
+                                output_array.append(model_output)
+                                output_batch_names += batch_names 
+                        if len(output_array) == 0:
+                            pbar.update(1)
+                            continue
+                        else:
+                            output_array = torch.cat(output_array, dim=0).cpu().numpy()
+                            output_batch_coords = get_coords(output_batch_names)
+                            # print(output_batch_coords)
+                            # z = zarr.group()
+                            # data = z.create_group('data')
+                            # tile_names = z.create_group('tile_names')
+                            # d1 = data.create_dataset('bag', shape=output_array.shape, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')
+                            # d2 = tile_names.create_dataset(output_batch_coords, shape=[len(output_batch_coords), 2], chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')
+
+                            # z['data'] = zarr.array(output_array, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='float') # 1792 = 224*8
+                            # z['tile_names'] = zarr.array(output_batch_coords, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='int32') # 1792 = 224*8
+                            # z.save
+                            # print(z['data'])
+                            # print(z['data'][:])
+                            # print(z['tile_names'][:])
+                            # zarr.save(output_path, z)
+                            # Persist features + coords as a zarr group on disk.
+                            zarr.save_group(output_path, data=output_array, coords=output_batch_coords)
+
+                            # z_test = zarr.open(output_path, 'r')
+                            # print(z_test['data'][:])
+                            # # print(z_test.tree())
+                            
+                            # if np.all(output_array== z_test['data'][:]):
+                            #     print('data true')
+                            # if np.all(z['tile_names'][:] == z_test['tile_names'][:]):
+                            #     print('tile_names true')
+                            #     print(output_path ' ')
+                            # print(np.all(z[:] == z_test[:]))
+
+                        # np.save(f'{str(slide)}.npy', slide_np)
+                            pbar.update(1)
+            
+
+                  
\ No newline at end of file
diff --git a/code/datasets/feature_extractor_2.py b/code/datasets/feature_extractor_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0057c8effb518dd68d00f33035081578d9615166
--- /dev/null
+++ b/code/datasets/feature_extractor_2.py
@@ -0,0 +1,186 @@
+import numpy as np
+from pathlib import Path
+from PIL import Image
+from tqdm import tqdm
+import zarr
+from numcodecs import Blosc
+import torch
+import torch.nn as nn
+import ResNet as ResNet 
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+import re
+from imgaug import augmenters as iaa
+
def chunker(seq, size):
    """Lazily yield consecutive slices of *seq* of length *size* (last may be shorter)."""
    for start in range(0, len(seq), size):
        yield seq[start:start + size]
+
def get_coords(batch_names): #ToDO: Change function for precise coords
    """Parse (x, y) grid coordinates from each tile name.

    The coordinates are taken from the LAST parenthesised ``(x_y)`` token in
    the name, so names containing several bracketed groups still resolve to
    their trailing position marker.
    """
    parsed = []
    for name in batch_names:
        inner = re.findall(r'\((.*?)\)', name)[-1]
        x_str, y_str = inner.split('_')
        parsed.append((int(x_str), int(y_str)))
    return parsed
+
def augment(img):
    """Apply one random colour/geometry augmentation pass to a single image array.

    A fresh deterministic snapshot of the pipeline is drawn per call, so each
    invocation yields an independent random augmentation.
    """
    # probability wrappers, named to match the original pipeline's augmenter names
    p50 = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
    p20 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
    p90_piece = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
    p90_elastic = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
    p90_affine = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")

    pipeline = iaa.Sequential([
        iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
        p20(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
        iaa.Fliplr(0.5, name="MyFlipLR"),
        iaa.Flipud(0.5, name="MyFlipUD"),
        p50(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
        # exactly one of the three spatial deformations per image
        iaa.OneOf([
            p90_piece(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
            p90_elastic(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
            p90_affine(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine")),
        ], name="MyOneOf"),
    ])
    return pipeline.to_deterministic().augment_image(img)
+
+
if __name__ == '__main__':
    # Feature-extraction driver: for every slide in the selected cohorts, run all
    # tiles through a frozen RetCCL ResNet50 five times (one pass per random
    # augmentation) and persist each pass as a zarr group of features + coords.

    home = Path.cwd().parts[1]
    
    data_root = Path(f'/{home}/ylan/data/DeepGraft/224_128um_v2')
    # output_path = Path(f'/{home}/ylan/wsi_tools/debug/zarr')
    cohorts = ['DEEPGRAFT_RA', 'Leuven'] #, 
    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU'] #, 
    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, 
    compressor = Blosc(cname='blosclz', clevel=3)  # NOTE(review): unused — zarr.save_group below uses its default codec

    # ImageNet normalisation applied to each augmented tile before the backbone
    val_transforms = transforms.Compose([
            # 
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            # RangeNormalization(),
        ])


    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    scaler = torch.cuda.amp.GradScaler()  # NOTE(review): unused — the autocast block below is commented out
    n_classes = 2
    out_features = 1024
    # RetCCL backbone: all weights frozen, BatchNorm pinned to eval statistics.
    model_ft = ResNet.resnet50(num_classes=n_classes, mlp=False, two_branch=False, normlinear=True)
    home = Path.cwd().parts[1]
    model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
    for param in model_ft.parameters():
        param.requires_grad = False
    for m in model_ft.modules():
        if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d):
            m.eval()
            m.weight.requires_grad = False
            m.bias.requires_grad = False
    model_ft.fc = nn.Linear(2048, out_features)  # NOTE(review): fc is freshly initialised here, i.e. a random projection to 1024-d
    model_ft.eval()
    model_ft.to(device)

    batch_size = 100  # tiles per forward pass

    for f in data_root.iterdir():
        
        if f.stem in cohorts:
            print(f)
            fe_path = f / 'FEATURES_RETCCL'
            fe_path.mkdir(parents=True, exist_ok=True)
            
            # num_files = len(list((f / 'BLOCKS').iterdir()))
            # collect per-slide tile directories, skipping any stray .zarr outputs
            slide_list = []
            for slide in (f / 'BLOCKS').iterdir():
                if Path(slide).is_dir(): 
                    if slide.suffix != '.zarr':
                        slide_list.append(slide)

            # NOTE(review): pbar total is len(slide_list) but update(1) fires once
            # per augmentation pass (5x per slide), so the bar overshoots its total.
            with tqdm(total=len(slide_list)) as pbar:
                for slide in slide_list:
                    # print('slide: ', slide)

                    # run every slide 5 times for augments
                    for n in range(5):


                        output_path = fe_path / Path(str(slide.stem) + f'_aug{n}.zarr')
                        # resume support: skip augmentation passes already on disk
                        if output_path.is_dir():
                                pbar.update(1)
                                print(output_path, ' skipped.')
                                continue
                        # else:
                        output_array = []
                        output_batch_names = []
                        for tile_path_batch in chunker(list(slide.iterdir()), batch_size):
                            batch_array = []
                            batch_names = []
                            for t in tile_path_batch:
                                # for n in range(5):
                                img = np.asarray(Image.open(str(t))).astype(np.uint8) #.astype(np.uint8)

                                img = augment(img)  # fresh random augmentation per tile/pass

                                img = val_transforms(img.copy()).to(device)
                                batch_array.append(img)

                                tile_name = t.stem
                                batch_names.append(tile_name)
                            if len(batch_array) == 0:
                                continue
                            else:
                                batch_array = torch.stack(batch_array) 
                                # with torch.cuda.amp.autocast():
                                model_output = model_ft(batch_array).detach()
                                output_array.append(model_output)
                                output_batch_names += batch_names 
                        if len(output_array) == 0:
                            pbar.update(1)
                            continue
                        else:
                            # one (num_tiles, out_features) feature matrix per pass
                            output_array = torch.cat(output_array, dim=0).cpu().numpy()
                            output_batch_coords = get_coords(output_batch_names)
                            # print(output_batch_coords)
                            # z = zarr.group()
                            # data = z.create_group('data')
                            # tile_names = z.create_group('tile_names')
                            # d1 = data.create_dataset('bag', shape=output_array.shape, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')
                            # d2 = tile_names.create_dataset(output_batch_coords, shape=[len(output_batch_coords), 2], chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')

                            # z['data'] = zarr.array(output_array, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='float') # 1792 = 224*8
                            # z['tile_names'] = zarr.array(output_batch_coords, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='int32') # 1792 = 224*8
                            # z.save
                            # print(z['data'])
                            # print(z['data'][:])
                            # print(z['tile_names'][:])
                            # zarr.save(output_path, z)
                            # persist features under 'data' and tile (x, y) coords under 'coords'
                            zarr.save_group(output_path, data=output_array, coords=output_batch_coords)

                            # z_test = zarr.open(output_path, 'r')
                            # print(z_test['data'][:])
                            # # print(z_test.tree())
                            
                            # if np.all(output_array== z_test['data'][:]):
                            #     print('data true')
                            # if np.all(z['tile_names'][:] == z_test['tile_names'][:]):
                            #     print('tile_names true')
                            #     print(output_path)
                            # print(np.all(z[:] == z_test[:]))

                        # np.save(f'{str(slide)}.npy', slide_np)
                            pbar.update(1)
\ No newline at end of file
diff --git a/code/datasets/simple_jpg_dataloader.py b/code/datasets/simple_jpg_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..332920a6b269ff1f4593e2f60e50ffa3109a4d12
--- /dev/null
+++ b/code/datasets/simple_jpg_dataloader.py
@@ -0,0 +1,322 @@
+# import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+import torch.utils.data as data_utils
+import torchvision.transforms as transforms
+import pandas as pd
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+from PIL import Image
+import cv2
+import json
+from imgaug import augmenters as iaa
+from torchsampler import ImbalancedDatasetSampler
+
+
class FeatureBagLoader(data_utils.Dataset):
    """Bag-of-tiles dataset: one item per whole-slide image (WSI).

    Slides are looked up under ``<file_path>/<cohort>/BLOCKS/<slide>`` for every
    cohort directory found under ``file_path``; labels come from the JSON split
    file at ``label_path``, keyed by ``mode`` ('train', 'val', ...). Each item is
    ``(bag, label, (wsi_name, tile_names, patient))`` where ``bag`` is a stacked
    tensor of normalised (and, in train mode, imgaug-augmented) tiles.

    Fixes vs. previous revision:
      * the cached validation branch iterated ``self.file_path`` (a path string,
        which iterates character by character) instead of ``self.files``;
      * ``get_data`` re-raises on missing label/patient instead of falling
        through to an UnboundLocalError;
      * the deterministic augmenter snapshot is drawn per slide (matching the
        non-cached ``__getitem__`` path) instead of once for the whole cache.
    """

    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=100, max_bag_size=1000):
        super().__init__()

        self.data_info = []
        self.data_cache = {}
        self.slideLabelDict = {}
        self.files = []
        self.data_cache_size = data_cache_size
        self.mode = mode
        self.file_path = file_path
        self.label_path = label_path
        self.n_classes = n_classes
        self.max_bag_size = max_bag_size
        self.min_bag_size = 120  # slides with fewer tiles are excluded
        self.empty_slides = []
        self.corrupt_slides = []
        # NOTE(review): caching is forced on; `load_data` is accepted but currently
        # ignored to keep behaviour identical for existing callers — confirm intent.
        self.cache = True

        # Map slide stem -> label and collect existing slide directories with
        # enough tiles; undersized slides are recorded in self.empty_slides.
        with open(self.label_path, 'r') as f:
            temp_slide_label_dict = json.load(f)[mode]
            print(len(temp_slide_label_dict))
            for (x, y) in temp_slide_label_dict:
                x = Path(x).stem
                for cohort in Path(self.file_path).iterdir():
                    x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
                    if x_complete_path.is_dir():
                        if len(list(x_complete_path.iterdir())) > self.min_bag_size:
                            self.slideLabelDict[x] = y
                            self.files.append(x_complete_path)
                        else:
                            self.empty_slides.append(x_complete_path)

        # slide -> patient mapping used to report provenance per bag
        home = Path.cwd().parts[1]
        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
        with open(self.slide_patient_dict_path, 'r') as f:
            self.slide_patient_dict = json.load(f)

        # imgaug training pipeline: colour jitter + flips/rotation
        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")

        self.train_transforms = iaa.Sequential([
            iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
            iaa.Fliplr(0.5, name="MyFlipLR"),
            iaa.Flipud(0.5, name="MyFlipUD"),
            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
            # iaa.OneOf([
            #     sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
            #     sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
            #     sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
            # ], name="MyOneOf")

        ], name="MyAug")
        # ImageNet normalisation applied to every tile, train or eval
        self.val_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            # RangeNormalization(),
        ])

        self.features = []
        self.labels = []
        self.wsi_names = []
        self.name_batches = []
        self.patients = []
        if self.cache:
            if mode == 'train':
                for t in tqdm(self.files):
                    batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
                    # draw a fresh deterministic augmenter per slide so each bag
                    # gets its own augmentation (previously hoisted out of the
                    # loop, which reused one augmentation for every slide)
                    seq_img_d = self.train_transforms.to_deterministic()
                    out_batch = []
                    for img in batch:
                        img = img.numpy().astype(np.uint8)
                        img = seq_img_d.augment_image(img)
                        img = self.val_transforms(img.copy())
                        out_batch.append(img)
                    out_batch = torch.stack(out_batch)
                    self.labels.append(label)
                    self.features.append(out_batch)
                    self.wsi_names.append(wsi_name)
                    self.name_batches.append(name_batch)
                    self.patients.append(patient)
            else:
                # BUGFIX: iterate the collected slide directories, not the
                # dataset-root path string (iterating a str yields characters).
                for t in tqdm(self.files):
                    batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
                    out_batch = []
                    for img in batch:
                        img = img.numpy().astype(np.uint8)
                        img = self.val_transforms(img.copy())
                        out_batch.append(img)
                    out_batch = torch.stack(out_batch)
                    self.labels.append(label)
                    self.features.append(out_batch)
                    self.wsi_names.append(wsi_name)
                    self.name_batches.append(name_batch)
                    self.patients.append(patient)

    def get_data(self, file_path):
        """Load every tile image of one slide directory into a uint8 tensor bag.

        Returns ``(wsi_batch, label, (wsi_name, tile_names, patient))``.
        Raises KeyError (after printing context) if the slide is missing from
        the label or patient maps.
        """
        wsi_batch = []
        name_batch = []
        for tile_path in Path(file_path).iterdir():
            img = np.asarray(Image.open(tile_path)).astype(np.uint8)
            img = torch.from_numpy(img)
            wsi_batch.append(img)
            name_batch.append(tile_path.stem)

        wsi_batch = torch.stack(wsi_batch)

        # cap the bag, then randomly drop 10% of tiles for regularisation
        if wsi_batch.size(0) > self.max_bag_size:
            wsi_batch, name_batch, _ = self.to_fixed_size_bag(wsi_batch, name_batch, self.max_bag_size)

        wsi_batch, name_batch = self.data_dropout(wsi_batch, name_batch, drop_rate=0.1)

        wsi_name = Path(file_path).stem
        try:
            label = self.slideLabelDict[wsi_name]
        except KeyError:
            # BUGFIX: re-raise instead of falling through to UnboundLocalError
            print(f'{wsi_name} is not included in label file {self.label_path}')
            raise

        try:
            patient = self.slide_patient_dict[wsi_name]
        except KeyError:
            print(f'{wsi_name} is not included in label file {self.slide_patient_dict_path}')
            raise

        return wsi_batch, label, (wsi_name, name_batch, patient)

    def get_labels(self, indices):
        """Labels for the given indices (ImbalancedDatasetSampler hook)."""
        return [self.labels[i] for i in indices]

    def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
        """Randomly subsample `bag` (and `names`) to at most `bag_size` tiles.

        Returns the sampled bag, the matching names, and the effective size.
        """
        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
        bag_samples = bag[bag_idxs]
        name_samples = [names[i] for i in bag_idxs]
        return bag_samples, name_samples, min(bag_size, len(bag))

    def data_dropout(self, bag, batch_names, drop_rate):
        """Randomly keep a (1 - drop_rate) fraction of the bag's tiles."""
        bag_size = bag.shape[0]
        bag_idxs = torch.randperm(bag_size)[:int(bag_size * (1 - drop_rate))]
        bag_samples = bag[bag_idxs]
        name_samples = [batch_names[i] for i in bag_idxs]
        return bag_samples, name_samples

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        if self.cache:
            label = self.labels[index]
            wsi = self.features[index]
            # NOTE(review): torch.Tensor(int) allocates an uninitialised tensor
            # of that *size*; if labels are plain ints this is likely wrong —
            # confirm whether torch.as_tensor(label) was intended.
            label = Variable(Tensor(label))
            wsi_name = self.wsi_names[index]
            name_batch = self.name_batches[index]
            patient = self.patients[index]
            return wsi, label, (wsi_name, name_batch, patient)
        else:
            if self.mode == 'train':
                batch, label, (wsi_name, name_batch, patient) = self.get_data(self.files[index])
                label = Variable(Tensor(label))
                out_batch = []
                seq_img_d = self.train_transforms.to_deterministic()
                for img in batch:
                    img = img.numpy().astype(np.uint8)
                    img = seq_img_d.augment_image(img)
                    img = self.val_transforms(img.copy())
                    out_batch.append(img)
                out_batch = torch.stack(out_batch)
            else:
                batch, label, (wsi_name, name_batch, patient) = self.get_data(self.files[index])
                label = Variable(Tensor(label))
                out_batch = []
                # no augmentation at validation/test time
                for img in batch:
                    img = img.numpy().astype(np.uint8)
                    img = self.val_transforms(img.copy())
                    out_batch.append(img)
                out_batch = torch.stack(out_batch)

            return out_batch, label, (wsi_name, name_batch, patient)
+
if __name__ == '__main__':
    
    from pathlib import Path
    import os
    import time
    from fast_tensor_dl import FastTensorDataLoader
    from custom_resnet50 import resnet50_baseline
    
    

    # Smoke test: cache a debug split of tile bags and time one feature-extraction
    # pass through a frozen ResNet50 backbone.
    home = Path.cwd().parts[1]
    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'  # NOTE(review): unused below
    data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
    # label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
    output_dir = f'/{data_root}/debug/augments'  # NOTE(review): data_root already starts with '/', producing a double slash
    os.makedirs(output_dir, exist_ok=True)

    n_classes = 2

    dataset = FeatureBagLoader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)

    # print(dataset.get_labels(0))
    # 80/20 random train/validation split
    a = int(len(dataset)* 0.8)
    b = int(len(dataset) - a)
    train_data, valid_data = random_split(dataset, [a, b])
    # print(dataset.dataset)
    # a = int(len(dataset)* 0.8)
    # b = int(len(dataset) - a)
    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
    dl = DataLoader(train_data, batch_size=1, num_workers=8, sampler=ImbalancedDatasetSampler(train_data), pin_memory=True)
    # print(len(dl))
    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    scaler = torch.cuda.amp.GradScaler()  # NOTE(review): unused — inference only, no backward pass

    model_ft = resnet50_baseline(pretrained=True)
    for param in model_ft.parameters():
        param.requires_grad = False  # frozen backbone: feature extraction only
    model_ft.to(device)
    
    c = 0
    label_count = [0] *n_classes  # NOTE(review): never updated below
    # print(len(dl))
    start = time.time()
    for item in tqdm(dl): 

        # if c >= 10:
        #     break
        bag, label, (name, batch_names, patient) = item
        # print(bag.shape)
        # print(len(batch_names))
        
        bag = bag.squeeze(0).float().to(device)  # drop the DataLoader batch dim -> (tiles, C, H, W)
        label = label.to(device)
        with torch.cuda.amp.autocast():
            output = model_ft(bag)
        c += 1
    end = time.time()

    print('Bag Time: ', end-start)
diff --git a/code/datasets/zarr_feature_dataloader.py b/code/datasets/zarr_feature_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b244b87c83a128d9b40b604fef8df647690d482b
--- /dev/null
+++ b/code/datasets/zarr_feature_dataloader.py
@@ -0,0 +1,238 @@
+import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from torchsampler import ImbalancedDatasetSampler
+from torchvision import datasets, transforms
+import pandas as pd
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+import zarr
+import json
+import cv2
+from PIL import Image
+# from models import TransMIL
+
+
+
class ZarrFeatureBagLoader(data.Dataset):
    """Dataset of pre-extracted feature bags stored as zarr groups.

    Feature files are looked up under
    ``<file_path>/<cohort>/FEATURES_RETCCL/<slide>.zarr``; each group holds a
    'data' array (per-tile feature vectors) and a 'coords' array (per-tile
    (x, y) positions). Items are
    ``(features, one_hot_label, (wsi_name, coords, patient))``.

    Fixes vs. previous revision:
      * the ``cache`` constructor argument is honoured (was hard-coded True);
      * ``get_data`` raises a descriptive KeyError for unknown slides instead
        of hitting an UnboundLocalError;
      * the non-cached ``__getitem__`` path one-hot encodes the label exactly
        like the cached path, so both paths return the same label type.
    """

    def __init__(self, file_path, label_path, mode, n_classes, cache=False, data_cache_size=100, max_bag_size=1000):
        super().__init__()

        self.data_info = []
        self.data_cache = {}
        self.slideLabelDict = {}
        self.files = []
        self.data_cache_size = data_cache_size
        self.mode = mode
        self.file_path = file_path
        self.label_path = label_path
        self.n_classes = n_classes
        self.max_bag_size = max_bag_size
        self.min_bag_size = 120
        self.empty_slides = []
        self.corrupt_slides = []
        # BUGFIX: honour the constructor argument (previously forced to True,
        # making `cache=False` impossible to select).
        self.cache = cache

        # Map slide stem -> label and collect the zarr feature stores that
        # actually exist on disk for any cohort.
        with open(self.label_path, 'r') as f:
            temp_slide_label_dict = json.load(f)[mode]
            for (x, y) in temp_slide_label_dict:
                x = Path(x).stem
                for cohort in Path(self.file_path).iterdir():
                    x_complete_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL' / (str(x) + '.zarr')
                    if x_complete_path.is_dir():
                        self.slideLabelDict[x] = y
                        self.files.append(x_complete_path)

        # slide -> patient mapping used to report provenance per bag
        home = Path.cwd().parts[1]
        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
        with open(self.slide_patient_dict_path, 'r') as f:
            self.slide_patient_dict = json.load(f)

        self.feature_bags = []
        self.labels = []
        self.wsi_names = []
        self.name_batches = []
        self.patients = []
        if self.cache:
            # eagerly load every bag into memory once
            for t in tqdm(self.files):
                batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
                self.labels.append(label)
                self.feature_bags.append(batch)
                self.wsi_names.append(wsi_name)
                self.name_batches.append(name_batch)
                self.patients.append(patient)

    def get_data(self, file_path, drop_rate=0.1):
        """Read one slide's feature bag from its zarr group.

        Applies random tile dropout of `drop_rate` for regularisation.
        Returns ``(features, label, (wsi_name, coords, patient))``.
        Raises KeyError if the slide is missing from the label/patient maps.
        """
        wsi_name = Path(file_path).stem
        try:
            label = self.slideLabelDict[wsi_name]
            patient = self.slide_patient_dict[wsi_name]
        except KeyError:
            # BUGFIX: previously fell through with `label`/`patient` unbound
            raise KeyError(f'{wsi_name} is missing from {self.label_path} or {self.slide_patient_dict_path}')

        z = zarr.open(file_path, 'r')
        np_bag = np.array(z['data'][:])
        wsi_bag = torch.from_numpy(np_bag)
        batch_coords = torch.from_numpy(np.array(z['coords'][:]))

        bag_size = wsi_bag.shape[0]
        # random tile dropout: keep a (1 - drop_rate) fraction, shuffled
        bag_idxs = torch.randperm(bag_size)[:int(bag_size * (1 - drop_rate))]
        wsi_bag = wsi_bag[bag_idxs, :]
        batch_coords = batch_coords[bag_idxs]
        return wsi_bag, label, (wsi_name, batch_coords, patient)

    def get_labels(self, indices):
        """Labels for the given indices (ImbalancedDatasetSampler hook)."""
        return [self.labels[i] for i in indices]

    def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
        """Randomly subsample `bag` (and `names`) to at most `bag_size` rows.

        Returns the sampled bag, the matching names, and the effective size.
        """
        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
        bag_samples = bag[bag_idxs]
        name_samples = [names[i] for i in bag_idxs]
        return bag_samples, name_samples, min(bag_size, len(bag))

    def data_dropout(self, bag, batch_names, drop_rate):
        """Randomly keep a (1 - drop_rate) fraction of the bag's rows."""
        bag_size = bag.shape[0]
        bag_idxs = torch.randperm(bag_size)[:int(bag_size * (1 - drop_rate))]
        bag_samples = bag[bag_idxs]
        name_samples = [batch_names[i] for i in bag_idxs]
        return bag_samples, name_samples

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        if self.cache:
            label = self.labels[index]
            wsi = self.feature_bags[index]
            # labels are class indices -> one-hot for the training loss
            label = torch.as_tensor(label)
            label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
            wsi_name = self.wsi_names[index]
            name_batch = self.name_batches[index]
            patient = self.patients[index]
            return wsi, label, (wsi_name, name_batch, patient)
        else:
            batch, label, (wsi_name, name_batch, patient) = self.get_data(self.files[index])
            # CONSISTENCY: encode the label exactly like the cached path
            label = torch.as_tensor(label)
            label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
            return batch, label, (wsi_name, name_batch, patient)
+
if __name__ == '__main__':
    
    from pathlib import Path
    import os
    import time

    # BUGFIX: TransMIL was referenced below but its import (commented out at the
    # top of the file) was never executed, causing a NameError at runtime.
    # Kept local so importing this module never pulls in the models package.
    from models import TransMIL

    # Smoke test: cache all zarr feature bags for the debug split, then time one
    # epoch of forward passes through TransMIL.
    home = Path.cwd().parts[1]
    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'  # NOTE(review): unused below
    data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
    output_dir = f'/{data_root}/debug/augments'  # NOTE(review): data_root already starts with '/', producing a double slash
    os.makedirs(output_dir, exist_ok=True)

    n_classes = 2

    dataset = ZarrFeatureBagLoader(data_root, label_path=label_path, mode='train', cache=True, n_classes=n_classes)

    # 80/20 random train/validation split
    a = int(len(dataset) * 0.8)
    b = int(len(dataset) - a)
    train_data, valid_data = random_split(dataset, [a, b])
    dl = DataLoader(train_data, batch_size=1, num_workers=8, sampler=ImbalancedDatasetSampler(train_data), pin_memory=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    scaler = torch.cuda.amp.GradScaler()  # NOTE(review): unused — inference only, no backward pass

    model = TransMIL(n_classes=n_classes).to(device)

    c = 0
    label_count = [0] * n_classes  # NOTE(review): never updated below
    epochs = 1
    for i in range(epochs):
        start = time.time()
        for item in tqdm(dl):
            bag, label, (name, batch_names, patient) = item
            print(label)
            print(batch_names)
            bag = bag.float().to(device)
            with torch.cuda.amp.autocast():
                output = model(bag)
        end = time.time()
        print('Bag Time: ', end - start)
diff --git a/code/label_map.json b/code/label_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..d77f55d8bc082c1fe7b3fddf2ad348167861f521
--- /dev/null
+++ b/code/label_map.json
@@ -0,0 +1 @@
+{"tcmr_viral":{"0": "TCMR", "1": "VIRAL"}, "no_viral": {"0": "STABLE", "1": "TCMR", "2": "ABMR", "3": "MIXED"}, "no_other":{"0": "STABLE", "1": "TCMR", "2": "ABMR", "3": "MIXED", "4": "VIRAL"}, "rejections": {"0": "TCMR", "1": "ABMR", "2": "MIXED"}, "all": {"0": "STABLE", "1": "TCMR", "2": "ABMR", "3": "MIXED", "4": "VIRAL", "5": "OTHER"}, "rej_rest": {"0": "REJECTION", "1": "REST"}, "norm_rest": {"0": "STABLE", "1": "DISEASE"}, "norm_rej_rest":{"0": "STABLE", "1": "REJECTION", "2": "REST"}}
\ No newline at end of file
diff --git a/code/lightning_logs/version_0/cm_test.png b/code/lightning_logs/version_0/cm_test.png
deleted file mode 100644
index 6fb672f0640e83968af9558a7d438cbf42b51815..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_0/cm_test.png and /dev/null differ
diff --git a/code/lightning_logs/version_0/events.out.tfevents.1657535217.dgx2.2080039.0 b/code/lightning_logs/version_0/events.out.tfevents.1657535217.dgx2.2080039.0
deleted file mode 100644
index 734f0a9372e6608b0ab5aed218222431f0044e31..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_0/events.out.tfevents.1657535217.dgx2.2080039.0 and /dev/null differ
diff --git a/code/lightning_logs/version_0/hparams.yaml b/code/lightning_logs/version_0/hparams.yaml
deleted file mode 100644
index de11b2861a631025df79cacf8d8a13415bbe1769..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_0/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/256_256um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/256_256um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_1/cm_test.png b/code/lightning_logs/version_1/cm_test.png
deleted file mode 100644
index e2ba8d1036ec9fa3e57078dd47b9a747a2809fc0..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_1/cm_test.png and /dev/null differ
diff --git a/code/lightning_logs/version_1/events.out.tfevents.1657535625.dgx2.2086189.0 b/code/lightning_logs/version_1/events.out.tfevents.1657535625.dgx2.2086189.0
deleted file mode 100644
index b11a9e985744cff77e2d36a3e726bb4b8b9ac89c..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_1/events.out.tfevents.1657535625.dgx2.2086189.0 and /dev/null differ
diff --git a/code/lightning_logs/version_1/hparams.yaml b/code/lightning_logs/version_1/hparams.yaml
deleted file mode 100644
index 3027219ffc70ee8a4f037c88505b459131ec11e4..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_1/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/256_256um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/256_256um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_10/events.out.tfevents.1657546166.dgx1.47613.0 b/code/lightning_logs/version_10/events.out.tfevents.1657546166.dgx1.47613.0
deleted file mode 100644
index afa287ceab85beb3b739a44d7e230d7ada1d726f..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_10/events.out.tfevents.1657546166.dgx1.47613.0 and /dev/null differ
diff --git a/code/lightning_logs/version_10/hparams.yaml b/code/lightning_logs/version_10/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_10/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_11/events.out.tfevents.1657546322.dgx1.48740.0 b/code/lightning_logs/version_11/events.out.tfevents.1657546322.dgx1.48740.0
deleted file mode 100644
index 328e3163a4bab058ba6a473d8d58e58f2fb51ac5..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_11/events.out.tfevents.1657546322.dgx1.48740.0 and /dev/null differ
diff --git a/code/lightning_logs/version_11/hparams.yaml b/code/lightning_logs/version_11/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_11/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_12/events.out.tfevents.1657546521.dgx1.50053.0 b/code/lightning_logs/version_12/events.out.tfevents.1657546521.dgx1.50053.0
deleted file mode 100644
index 692f51dfebf86309232688295f7c9f36fc3096a5..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_12/events.out.tfevents.1657546521.dgx1.50053.0 and /dev/null differ
diff --git a/code/lightning_logs/version_12/hparams.yaml b/code/lightning_logs/version_12/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_12/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_13/events.out.tfevents.1657546918.dgx1.52290.0 b/code/lightning_logs/version_13/events.out.tfevents.1657546918.dgx1.52290.0
deleted file mode 100644
index db43abd93eaa392fca46ac8cac0cf6ea91076b76..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_13/events.out.tfevents.1657546918.dgx1.52290.0 and /dev/null differ
diff --git a/code/lightning_logs/version_13/hparams.yaml b/code/lightning_logs/version_13/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_13/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_14/events.out.tfevents.1657546992.dgx1.53435.0 b/code/lightning_logs/version_14/events.out.tfevents.1657546992.dgx1.53435.0
deleted file mode 100644
index 3ac295e5023bf3d8ac7d7a060eb0aa861cfe9c80..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_14/events.out.tfevents.1657546992.dgx1.53435.0 and /dev/null differ
diff --git a/code/lightning_logs/version_14/hparams.yaml b/code/lightning_logs/version_14/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_14/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_15/events.out.tfevents.1657547134.dgx1.54703.0 b/code/lightning_logs/version_15/events.out.tfevents.1657547134.dgx1.54703.0
deleted file mode 100644
index 797ee6e3689ccb965f4e97b808281e986ef758b0..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_15/events.out.tfevents.1657547134.dgx1.54703.0 and /dev/null differ
diff --git a/code/lightning_logs/version_15/hparams.yaml b/code/lightning_logs/version_15/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_15/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_16/events.out.tfevents.1657547198.dgx1.55641.0 b/code/lightning_logs/version_16/events.out.tfevents.1657547198.dgx1.55641.0
deleted file mode 100644
index b4cf4197bf688807ead868952cb4b66fdee7b451..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_16/events.out.tfevents.1657547198.dgx1.55641.0 and /dev/null differ
diff --git a/code/lightning_logs/version_16/hparams.yaml b/code/lightning_logs/version_16/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_16/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_17/events.out.tfevents.1657623153.dgx1.41577.0 b/code/lightning_logs/version_17/events.out.tfevents.1657623153.dgx1.41577.0
deleted file mode 100644
index 8236270b84bc8d5cbd60a4f1b08f8752fb734a6a..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_17/events.out.tfevents.1657623153.dgx1.41577.0 and /dev/null differ
diff --git a/code/lightning_logs/version_17/hparams.yaml b/code/lightning_logs/version_17/hparams.yaml
deleted file mode 100644
index 399847bb2e810ee8b9dde0c0723e1615ac8d92dc..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_17/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 7
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_18/events.out.tfevents.1657624768.dgx1.72352.0 b/code/lightning_logs/version_18/events.out.tfevents.1657624768.dgx1.72352.0
deleted file mode 100644
index e2c425bbcbeaabdd74a5233ee4679ec8c0b467cf..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_18/events.out.tfevents.1657624768.dgx1.72352.0 and /dev/null differ
diff --git a/code/lightning_logs/version_18/hparams.yaml b/code/lightning_logs/version_18/hparams.yaml
deleted file mode 100644
index 5d3068796e8af6c6aacc06462015239f67239118..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_18/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_19/events.out.tfevents.1657624869.dgx1.76706.0 b/code/lightning_logs/version_19/events.out.tfevents.1657624869.dgx1.76706.0
deleted file mode 100644
index da36871f7274cacd3a80d00f386c7f1036749c2c..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_19/events.out.tfevents.1657624869.dgx1.76706.0 and /dev/null differ
diff --git a/code/lightning_logs/version_19/hparams.yaml b/code/lightning_logs/version_19/hparams.yaml
deleted file mode 100644
index 5d3068796e8af6c6aacc06462015239f67239118..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_19/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_2/events.out.tfevents.1657543650.dgx1.33429.0 b/code/lightning_logs/version_2/events.out.tfevents.1657543650.dgx1.33429.0
deleted file mode 100644
index 0e8f866c6b4bb783b058864f37241e2a7966e66a..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_2/events.out.tfevents.1657543650.dgx1.33429.0 and /dev/null differ
diff --git a/code/lightning_logs/version_2/hparams.yaml b/code/lightning_logs/version_2/hparams.yaml
deleted file mode 100644
index 3978f85d8fc605d6a8b9c42d1590620769fbdf5d..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_2/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 0
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/256_256um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/256_256um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_20/events.out.tfevents.1657625133.dgx1.5999.0 b/code/lightning_logs/version_20/events.out.tfevents.1657625133.dgx1.5999.0
deleted file mode 100644
index 30bdf9ae84865abecf1bbf3fe104f84a7e9f681e..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_20/events.out.tfevents.1657625133.dgx1.5999.0 and /dev/null differ
diff --git a/code/lightning_logs/version_20/hparams.yaml b/code/lightning_logs/version_20/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_20/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_21/events.out.tfevents.1657625248.dgx1.11114.0 b/code/lightning_logs/version_21/events.out.tfevents.1657625248.dgx1.11114.0
deleted file mode 100644
index 580ae5449c6c891452ad6bb188dade7852059ad3..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_21/events.out.tfevents.1657625248.dgx1.11114.0 and /dev/null differ
diff --git a/code/lightning_logs/version_21/hparams.yaml b/code/lightning_logs/version_21/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_21/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_22/events.out.tfevents.1657625470.dgx1.20071.0 b/code/lightning_logs/version_22/events.out.tfevents.1657625470.dgx1.20071.0
deleted file mode 100644
index 259e2f747c61500222f0b1faf00b2d6b3a41f41e..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_22/events.out.tfevents.1657625470.dgx1.20071.0 and /dev/null differ
diff --git a/code/lightning_logs/version_22/hparams.yaml b/code/lightning_logs/version_22/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_22/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_23/events.out.tfevents.1657625510.dgx1.22295.0 b/code/lightning_logs/version_23/events.out.tfevents.1657625510.dgx1.22295.0
deleted file mode 100644
index c74c353abdb31d4bab1671520ec1793298923045..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_23/events.out.tfevents.1657625510.dgx1.22295.0 and /dev/null differ
diff --git a/code/lightning_logs/version_23/hparams.yaml b/code/lightning_logs/version_23/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_23/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_24/events.out.tfevents.1657625570.dgx1.25099.0 b/code/lightning_logs/version_24/events.out.tfevents.1657625570.dgx1.25099.0
deleted file mode 100644
index d23806a37c9d31bcee1beeafa2689732bbe3ab9c..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_24/events.out.tfevents.1657625570.dgx1.25099.0 and /dev/null differ
diff --git a/code/lightning_logs/version_24/hparams.yaml b/code/lightning_logs/version_24/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_24/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_25/events.out.tfevents.1657625613.dgx1.27343.0 b/code/lightning_logs/version_25/events.out.tfevents.1657625613.dgx1.27343.0
deleted file mode 100644
index a51495644c8c24447f5a9d0bce559fe1c0f7e857..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_25/events.out.tfevents.1657625613.dgx1.27343.0 and /dev/null differ
diff --git a/code/lightning_logs/version_25/hparams.yaml b/code/lightning_logs/version_25/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_25/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_26/events.out.tfevents.1657625763.dgx1.33397.0 b/code/lightning_logs/version_26/events.out.tfevents.1657625763.dgx1.33397.0
deleted file mode 100644
index 19376c22a488f8e753dbdbff5f0e810aa1ae4057..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_26/events.out.tfevents.1657625763.dgx1.33397.0 and /dev/null differ
diff --git a/code/lightning_logs/version_26/hparams.yaml b/code/lightning_logs/version_26/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_26/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_27/events.out.tfevents.1657625819.dgx1.36236.0 b/code/lightning_logs/version_27/events.out.tfevents.1657625819.dgx1.36236.0
deleted file mode 100644
index 63e1dd8a5203e7fb1ff819e29310446d3cd72067..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_27/events.out.tfevents.1657625819.dgx1.36236.0 and /dev/null differ
diff --git a/code/lightning_logs/version_27/hparams.yaml b/code/lightning_logs/version_27/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_27/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_28/events.out.tfevents.1657625859.dgx1.38333.0 b/code/lightning_logs/version_28/events.out.tfevents.1657625859.dgx1.38333.0
deleted file mode 100644
index 6b6c0ae6852731226a72c9746e71c67c5b96361b..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_28/events.out.tfevents.1657625859.dgx1.38333.0 and /dev/null differ
diff --git a/code/lightning_logs/version_28/hparams.yaml b/code/lightning_logs/version_28/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_28/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_29/events.out.tfevents.1657625903.dgx1.40628.0 b/code/lightning_logs/version_29/events.out.tfevents.1657625903.dgx1.40628.0
deleted file mode 100644
index 8783be4dc429493cadb5cc21bca2dc1a24e98d8f..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_29/events.out.tfevents.1657625903.dgx1.40628.0 and /dev/null differ
diff --git a/code/lightning_logs/version_29/hparams.yaml b/code/lightning_logs/version_29/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_29/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_3/cm_test.png b/code/lightning_logs/version_3/cm_test.png
deleted file mode 100644
index 4a523c98c4bd81c8e446deba57c6e24579f4911a..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_3/cm_test.png and /dev/null differ
diff --git a/code/lightning_logs/version_3/events.out.tfevents.1657543830.dgx1.34643.0 b/code/lightning_logs/version_3/events.out.tfevents.1657543830.dgx1.34643.0
deleted file mode 100644
index 0726f25de4e54bfa2acdf91a638a4d86787f60ac..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_3/events.out.tfevents.1657543830.dgx1.34643.0 and /dev/null differ
diff --git a/code/lightning_logs/version_3/hparams.yaml b/code/lightning_logs/version_3/hparams.yaml
deleted file mode 100644
index a662047b861838dd7c0c87b5d45111c35a9a254d..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_3/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 0
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_30/events.out.tfevents.1657625960.dgx1.43463.0 b/code/lightning_logs/version_30/events.out.tfevents.1657625960.dgx1.43463.0
deleted file mode 100644
index 5a9ead248d726804faeead94babbb17b53639f97..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_30/events.out.tfevents.1657625960.dgx1.43463.0 and /dev/null differ
diff --git a/code/lightning_logs/version_30/hparams.yaml b/code/lightning_logs/version_30/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_30/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_31/events.out.tfevents.1657626120.dgx1.48909.0 b/code/lightning_logs/version_31/events.out.tfevents.1657626120.dgx1.48909.0
deleted file mode 100644
index 125ad0c8b29dcda71447e15108aa5dc4d5d085aa..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_31/events.out.tfevents.1657626120.dgx1.48909.0 and /dev/null differ
diff --git a/code/lightning_logs/version_31/hparams.yaml b/code/lightning_logs/version_31/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_31/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_32/events.out.tfevents.1657626628.dgx1.65460.0 b/code/lightning_logs/version_32/events.out.tfevents.1657626628.dgx1.65460.0
deleted file mode 100644
index 32f354dc772b3edd2b278e47b884bfb45af1ef7e..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_32/events.out.tfevents.1657626628.dgx1.65460.0 and /dev/null differ
diff --git a/code/lightning_logs/version_32/hparams.yaml b/code/lightning_logs/version_32/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_32/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_33/events.out.tfevents.1657626794.dgx1.71384.0 b/code/lightning_logs/version_33/events.out.tfevents.1657626794.dgx1.71384.0
deleted file mode 100644
index 1c2bc9d73a885ec5d45e076338de61fb63e3475d..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_33/events.out.tfevents.1657626794.dgx1.71384.0 and /dev/null differ
diff --git a/code/lightning_logs/version_33/hparams.yaml b/code/lightning_logs/version_33/hparams.yaml
deleted file mode 100644
index d8ca6a0997effb2e3843c368c93978c2621870cb..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_33/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - test
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: test
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_4/events.out.tfevents.1657545281.dgx1.41143.0 b/code/lightning_logs/version_4/events.out.tfevents.1657545281.dgx1.41143.0
deleted file mode 100644
index 04c542e7c7c8be7ebf3a1c040e54855ccc321392..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_4/events.out.tfevents.1657545281.dgx1.41143.0 and /dev/null differ
diff --git a/code/lightning_logs/version_4/hparams.yaml b/code/lightning_logs/version_4/hparams.yaml
deleted file mode 100644
index a662047b861838dd7c0c87b5d45111c35a9a254d..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_4/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 0
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_5/events.out.tfevents.1657545407.dgx1.42159.0 b/code/lightning_logs/version_5/events.out.tfevents.1657545407.dgx1.42159.0
deleted file mode 100644
index de560d26010186b57ca59ac3412d2603077d9c84..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_5/events.out.tfevents.1657545407.dgx1.42159.0 and /dev/null differ
diff --git a/code/lightning_logs/version_5/hparams.yaml b/code/lightning_logs/version_5/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_5/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_6/events.out.tfevents.1657545574.dgx1.43290.0 b/code/lightning_logs/version_6/events.out.tfevents.1657545574.dgx1.43290.0
deleted file mode 100644
index b96b3cc792298ad93aea695082529868f03f3aa0..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_6/events.out.tfevents.1657545574.dgx1.43290.0 and /dev/null differ
diff --git a/code/lightning_logs/version_6/hparams.yaml b/code/lightning_logs/version_6/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_6/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_7/events.out.tfevents.1657545679.dgx1.44154.0 b/code/lightning_logs/version_7/events.out.tfevents.1657545679.dgx1.44154.0
deleted file mode 100644
index 7094035f824ab5b403ca76dc451d320b83f282b6..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_7/events.out.tfevents.1657545679.dgx1.44154.0 and /dev/null differ
diff --git a/code/lightning_logs/version_7/hparams.yaml b/code/lightning_logs/version_7/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_7/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_8/events.out.tfevents.1657545837.dgx1.45339.0 b/code/lightning_logs/version_8/events.out.tfevents.1657545837.dgx1.45339.0
deleted file mode 100644
index feb362684898989fdf7cfbdf3ce96d9eeef16d25..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_8/events.out.tfevents.1657545837.dgx1.45339.0 and /dev/null differ
diff --git a/code/lightning_logs/version_8/hparams.yaml b/code/lightning_logs/version_8/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_8/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/lightning_logs/version_9/events.out.tfevents.1657545985.dgx1.46385.0 b/code/lightning_logs/version_9/events.out.tfevents.1657545985.dgx1.46385.0
deleted file mode 100644
index 1f6b05a4e88ebf59fe40d3af43cf85a4991a21c0..0000000000000000000000000000000000000000
Binary files a/code/lightning_logs/version_9/events.out.tfevents.1657545985.dgx1.46385.0 and /dev/null differ
diff --git a/code/lightning_logs/version_9/hparams.yaml b/code/lightning_logs/version_9/hparams.yaml
deleted file mode 100644
index 385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f..0000000000000000000000000000000000000000
--- a/code/lightning_logs/version_9/hparams.yaml
+++ /dev/null
@@ -1,368 +0,0 @@
-backbone: resnet50
-cfg: &id010 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - General
-    - &id002 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - comment
-        - null
-      - !!python/tuple
-        - seed
-        - 2021
-      - !!python/tuple
-        - fp16
-        - true
-      - !!python/tuple
-        - amp_level
-        - O2
-      - !!python/tuple
-        - precision
-        - 16
-      - !!python/tuple
-        - multi_gpu_mode
-        - dp
-      - !!python/tuple
-        - gpus
-        - &id001
-          - 1
-      - !!python/tuple
-        - epochs
-        - 200
-      - !!python/tuple
-        - grad_acc
-        - 2
-      - !!python/tuple
-        - frozen_bn
-        - false
-      - !!python/tuple
-        - patience
-        - 50
-      - !!python/tuple
-        - server
-        - train
-      - !!python/tuple
-        - log_path
-        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
-      dictitems:
-        amp_level: O2
-        comment: null
-        epochs: 200
-        fp16: true
-        frozen_bn: false
-        gpus: *id001
-        grad_acc: 2
-        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
-        multi_gpu_mode: dp
-        patience: 50
-        precision: 16
-        seed: 2021
-        server: train
-      state: *id002
-  - !!python/tuple
-    - Data
-    - &id005 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - dataset_name
-        - custom
-      - !!python/tuple
-        - data_shuffle
-        - false
-      - !!python/tuple
-        - data_dir
-        - /home/ylan/data/DeepGraft/224_128um/
-      - !!python/tuple
-        - label_file
-        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-      - !!python/tuple
-        - fold
-        - 0
-      - !!python/tuple
-        - nfold
-        - 3
-      - !!python/tuple
-        - cross_val
-        - false
-      - !!python/tuple
-        - train_dataloader
-        - &id003 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id003
-      - !!python/tuple
-        - test_dataloader
-        - &id004 !!python/object/new:addict.addict.Dict
-          args:
-          - !!python/tuple
-            - batch_size
-            - 1
-          - !!python/tuple
-            - num_workers
-            - 8
-          dictitems:
-            batch_size: 1
-            num_workers: 8
-          state: *id004
-      - !!python/tuple
-        - bag_size
-        - 1024
-      dictitems:
-        bag_size: 1024
-        cross_val: false
-        data_dir: /home/ylan/data/DeepGraft/224_128um/
-        data_shuffle: false
-        dataset_name: custom
-        fold: 0
-        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
-        nfold: 3
-        test_dataloader: *id004
-        train_dataloader: *id003
-      state: *id005
-  - !!python/tuple
-    - Model
-    - &id006 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - name
-        - TransMIL
-      - !!python/tuple
-        - n_classes
-        - 2
-      - !!python/tuple
-        - backbone
-        - resnet50
-      - !!python/tuple
-        - in_features
-        - 512
-      - !!python/tuple
-        - out_features
-        - 512
-      dictitems:
-        backbone: resnet50
-        in_features: 512
-        n_classes: 2
-        name: TransMIL
-        out_features: 512
-      state: *id006
-  - !!python/tuple
-    - Optimizer
-    - &id007 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - opt
-        - lookahead_radam
-      - !!python/tuple
-        - lr
-        - 0.0002
-      - !!python/tuple
-        - opt_eps
-        - null
-      - !!python/tuple
-        - opt_betas
-        - null
-      - !!python/tuple
-        - momentum
-        - null
-      - !!python/tuple
-        - weight_decay
-        - 0.01
-      dictitems:
-        lr: 0.0002
-        momentum: null
-        opt: lookahead_radam
-        opt_betas: null
-        opt_eps: null
-        weight_decay: 0.01
-      state: *id007
-  - !!python/tuple
-    - Loss
-    - &id008 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - base_loss
-        - CrossEntropyLoss
-      dictitems:
-        base_loss: CrossEntropyLoss
-      state: *id008
-  - !!python/tuple
-    - config
-    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-  - !!python/tuple
-    - version
-    - 4
-  - !!python/tuple
-    - epoch
-    - '159'
-  - !!python/tuple
-    - log_path
-    - &id009 !!python/object/apply:pathlib.PosixPath
-      - /
-      - home
-      - ylan
-      - workspace
-      - TransMIL-DeepGraft
-      - logs
-      - DeepGraft
-      - TransMIL
-      - tcmr_viral
-      - _resnet50_CrossEntropyLoss
-      - lightning_logs
-      - version_4
-  dictitems:
-    Data: *id005
-    General: *id002
-    Loss: *id008
-    Model: *id006
-    Optimizer: *id007
-    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
-    epoch: '159'
-    log_path: *id009
-    version: 4
-  state: *id010
-data: &id013 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - dataset_name
-    - custom
-  - !!python/tuple
-    - data_shuffle
-    - false
-  - !!python/tuple
-    - data_dir
-    - /home/ylan/data/DeepGraft/224_128um/
-  - !!python/tuple
-    - label_file
-    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-  - !!python/tuple
-    - fold
-    - 0
-  - !!python/tuple
-    - nfold
-    - 3
-  - !!python/tuple
-    - cross_val
-    - false
-  - !!python/tuple
-    - train_dataloader
-    - &id011 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id011
-  - !!python/tuple
-    - test_dataloader
-    - &id012 !!python/object/new:addict.addict.Dict
-      args:
-      - !!python/tuple
-        - batch_size
-        - 1
-      - !!python/tuple
-        - num_workers
-        - 8
-      dictitems:
-        batch_size: 1
-        num_workers: 8
-      state: *id012
-  - !!python/tuple
-    - bag_size
-    - 1024
-  dictitems:
-    bag_size: 1024
-    cross_val: false
-    data_dir: /home/ylan/data/DeepGraft/224_128um/
-    data_shuffle: false
-    dataset_name: custom
-    fold: 0
-    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
-    nfold: 3
-    test_dataloader: *id012
-    train_dataloader: *id011
-  state: *id013
-log: !!python/object/apply:pathlib.PosixPath
-- /
-- home
-- ylan
-- workspace
-- TransMIL-DeepGraft
-- logs
-- DeepGraft
-- TransMIL
-- tcmr_viral
-- _resnet50_CrossEntropyLoss
-loss: &id014 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - base_loss
-    - CrossEntropyLoss
-  dictitems:
-    base_loss: CrossEntropyLoss
-  state: *id014
-model: &id015 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - name
-    - TransMIL
-  - !!python/tuple
-    - n_classes
-    - 2
-  - !!python/tuple
-    - backbone
-    - resnet50
-  - !!python/tuple
-    - in_features
-    - 512
-  - !!python/tuple
-    - out_features
-    - 512
-  dictitems:
-    backbone: resnet50
-    in_features: 512
-    n_classes: 2
-    name: TransMIL
-    out_features: 512
-  state: *id015
-optimizer: &id016 !!python/object/new:addict.addict.Dict
-  args:
-  - !!python/tuple
-    - opt
-    - lookahead_radam
-  - !!python/tuple
-    - lr
-    - 0.0002
-  - !!python/tuple
-    - opt_eps
-    - null
-  - !!python/tuple
-    - opt_betas
-    - null
-  - !!python/tuple
-    - momentum
-    - null
-  - !!python/tuple
-    - weight_decay
-    - 0.01
-  dictitems:
-    lr: 0.0002
-    momentum: null
-    opt: lookahead_radam
-    opt_betas: null
-    opt_eps: null
-    weight_decay: 0.01
-  state: *id016
diff --git a/code/models/AttMIL.py b/code/models/AttMIL.py
index d4a5938c1a27531cf992f9cad099ba445e7d093c..048fb0fc782aa856c1e7c6edcc850e0e3cb39552 100644
--- a/code/models/AttMIL.py
+++ b/code/models/AttMIL.py
@@ -76,4 +76,5 @@ class AttMIL(nn.Module): #gated attention
         M = torch.mm(A, H)  # KxL
         logits = self.classifier(M)
        
-        return logits, A
\ No newline at end of file
+        return logits, A
+
diff --git a/code/models/Chowder.py b/code/models/Chowder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8c4e186b3506a17720dbb08160548ae6e35e3e4
--- /dev/null
+++ b/code/models/Chowder.py
@@ -0,0 +1,59 @@
+import os
+import logging
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+import pytorch_lightning as pl
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+
+'''
+Chowder implementation from Courtiol 2018(https://openreview.net/pdf?id=ryserbZR-) 
+'''
+
+
+class Chowder(nn.Module): 
+    def __init__(self, n_classes, features=512, r=5):
+        super(Chowder, self).__init__()
+        self.L = features
+        self.n_classes = n_classes
+        self.R = r
+
+
+        self.f1 = nn.Sequential(
+            nn.Conv1d(self.L, 1, 1),
+        )
+        self.f2 = nn.Sequential(
+            nn.Linear(r*2, 200),
+            nn.Linear(200, 100),
+            nn.Linear(100, self.n_classes),
+            # nn.Sigmoid()
+        )
+        
+    def forward(self, x):
+
+        x = x.float()
+        x = torch.transpose(x, 1, 2)
+
+        x = self.f1(x)
+        max_indices = torch.topk(x, self.R).values
+        min_indices = torch.topk(x, self.R, largest=False).values
+
+        cat_minmax = torch.cat((min_indices, max_indices), dim=2)
+
+        out = self.f2(cat_minmax).squeeze(0)
+        
+        return out, None
+
+if __name__ == '__main__':
+
+    data = torch.randn((1, 6000, 512)).cuda()
+    model = Chowder(n_classes=2).cuda()
+    print(model.eval())
+    logits, _ = model(data)
+
+    print(logits.shape)
\ No newline at end of file
diff --git a/code/models/ConvMixer.py b/code/models/ConvMixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c9da60d10d18e00a83150cbb29e1c9cb28f1854
--- /dev/null
+++ b/code/models/ConvMixer.py
@@ -0,0 +1,182 @@
+import torch
+from torch import nn
+
+from labml_helpers.module import Module
+from labml_nn.utils import clone_module_list
+
+
+class ConvMixerLayer(Module):
+    """
+    <a id="ConvMixerLayer"></a>
+    ## ConvMixer layer
+    This is a single ConvMixer layer. The model will have a series of these.
+    """
+
+    def __init__(self, d_model: int, kernel_size: int):
+        """
+        * `d_model` is the number of channels in patch embeddings, $h$
+        * `kernel_size` is the size of the kernel of spatial convolution, $k$
+        """
+        super().__init__()
+        # Depth-wise convolution is separate convolution for each channel.
+        # We do this with a convolution layer with the number of groups equal to the number of channels.
+        # So that each channel is it's own group.
+        self.depth_wise_conv = nn.Conv2d(d_model, d_model,
+                                         kernel_size=kernel_size,
+                                         groups=d_model,
+                                         padding=(kernel_size - 1) // 2)
+        # Activation after depth-wise convolution
+        self.act1 = nn.GELU()
+        # Normalization after depth-wise convolution
+        self.norm1 = nn.BatchNorm2d(d_model)
+
+        # Point-wise convolution is a $1 \times 1$ convolution.
+        # i.e. a linear transformation of patch embeddings
+        self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)
+        # Activation after point-wise convolution
+        self.act2 = nn.GELU()
+        # Normalization after point-wise convolution
+        self.norm2 = nn.BatchNorm2d(d_model)
+
+    def forward(self, x: torch.Tensor):
+        # For the residual connection around the depth-wise convolution
+        residual = x
+
+        # Depth-wise convolution, activation and normalization
+        x = self.depth_wise_conv(x)
+        x = self.act1(x)
+        x = self.norm1(x)
+
+        # Add residual connection
+        x += residual
+
+        # Point-wise convolution, activation and normalization
+        x = self.point_wise_conv(x)
+        x = self.act2(x)
+        x = self.norm2(x)
+
+        #
+        return x
+
+
+class PatchEmbeddings(Module):
+    """
+    <a id="PatchEmbeddings"></a>
+    ## Get patch embeddings
+    This splits the image into patches of size $p \times p$ and gives an embedding for each patch.
+    """
+
+    def __init__(self, d_model: int, patch_size: int, in_channels: int):
+        """
+        * `d_model` is the number of channels in patch embeddings $h$
+        * `patch_size` is the size of the patch, $p$
+        * `in_channels` is the number of channels in the input image (3 for rgb)
+        """
+        super().__init__()
+
+        # We create a convolution layer with a kernel size and and stride length equal to patch size.
+        # This is equivalent to splitting the image into patches and doing a linear
+        # transformation on each patch.
+        self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
+        # Activation function
+        self.act = nn.GELU()
+        # Batch normalization
+        self.norm = nn.BatchNorm2d(d_model)
+
+    def forward(self, x: torch.Tensor):
+        """
+        * `x` is the input image of shape `[batch_size, channels, height, width]`
+        """
+        # Apply convolution layer
+        x = self.conv(x)
+        # Activation and normalization
+        x = self.act(x)
+        x = self.norm(x)
+
+        #
+        return x
+
+
+class ClassificationHead(Module):
+    """
+    <a id="ClassificationHead"></a>
+    ## Classification Head
+    They do average pooling (taking the mean of all patch embeddings) and a final linear transformation
+    to predict the log-probabilities of the image classes.
+    """
+
+    def __init__(self, d_model: int, n_classes: int):
+        """
+        * `d_model` is the number of channels in patch embeddings, $h$
+        * `n_classes` is the number of classes in the classification task
+        """
+        super().__init__()
+        # Average Pool
+        self.pool = nn.AdaptiveAvgPool2d((1, 1))
+        # Linear layer
+        self.linear = nn.Linear(d_model, n_classes)
+
+    def forward(self, x: torch.Tensor):
+        # Average pooling
+        x = self.pool(x)
+        # Get the embedding, `x` will have shape `[batch_size, d_model, 1, 1]`
+        x = x[:, :, 0, 0]
+        # Linear layer
+        x = self.linear(x)
+
+        #
+        return x
+
+
+class ConvMixer(Module):
+    """
+    ## ConvMixer
+    This combines the patch embeddings block, a number of ConvMixer layers and a classification head.
+    """
+
+    def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
+                 patch_emb: PatchEmbeddings,
+                 classification: ClassificationHead):
+        """
+        * `conv_mixer_layer` is a copy of a single [ConvMixer layer](#ConvMixerLayer).
+         We make copies of it to make ConvMixer with `n_layers`.
+        * `n_layers` is the number of ConvMixer layers (or depth), $d$.
+        * `patch_emb` is the [patch embeddings layer](#PatchEmbeddings).
+        * `classification` is the [classification head](#ClassificationHead).
+        """
+        super().__init__()
+        # Patch embeddings
+        self.patch_emb = patch_emb
+        # Classification head
+        self.classification = classification
+        # Make copies of the [ConvMixer layer](#ConvMixerLayer)
+        self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
+
+    def forward(self, x: torch.Tensor):
+        """
+        * `x` is the input image of shape `[batch_size, channels, height, width]`
+        """
+        # Get patch embeddings. This gives a tensor of shape `[batch_size, d_model, height / patch_size, width / patch_size]`.
+        x = self.patch_emb(x)
+        # print(x.shape)
+        # Pass through [ConvMixer layers](#ConvMixerLayer)
+        for layer in self.conv_mixer_layers:
+            x = layer(x)
+
+        # Classification head, to get logits
+        x = self.classification(x)
+
+        #
+        return x
+
+if __name__ == '__main__':
+
+
+    convmix = ConvMixerLayer(d_model=20, kernel_size=9)
+    patch_emb = PatchEmbeddings(d_model=20, patch_size=7, in_channels=3)
+    classification = ClassificationHead(d_model=20, n_classes=2)
+    model = ConvMixer(conv_mixer_layer=convmix, n_layers=20, patch_emb = patch_emb, classification=classification )
+    x = torch.randn([1,3,224,224])
+    
+    y = model(x)
+    print(y.shape)
\ No newline at end of file
diff --git a/code/models/MDMIL.py b/code/models/MDMIL.py
new file mode 100644
index 0000000000000000000000000000000000000000..db61da50e7fddd160d49a711ca006a63cf2c6ebc
--- /dev/null
+++ b/code/models/MDMIL.py
@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from nystrom_attention import NystromAttention
+
+
+class TransLayer(nn.Module):
+
+    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
+        super().__init__()
+        self.norm = norm_layer(dim)
+        self.attn = NystromAttention(
+            dim = dim,
+            dim_head = dim//8,
+            heads = 8,
+            num_landmarks = dim//2,    # number of landmarks
+            pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
+            residual = True,         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
+            dropout=0.7 #0.1
+        )
+
+    def forward(self, x):
+        out, attn = self.attn(self.norm(x), return_attn=True)
+        x = x + out
+        # x = x + self.attn(self.norm(x))
+
+        return x, attn
+
+
+class PPEG(nn.Module):
+    def __init__(self, dim=512):
+        super(PPEG, self).__init__()
+        self.proj = nn.Conv2d(dim, dim, 7, 1, 7//2, groups=dim)
+        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5//2, groups=dim)
+        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3//2, groups=dim)
+
+    def forward(self, x, H, W):
+        B, _, C = x.shape
+        cls_token, feat_token = x[:, 0], x[:, 1:]
+        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
+        x = self.proj(cnn_feat)+cnn_feat+self.proj1(cnn_feat)+self.proj2(cnn_feat)
+        x = x.flatten(2).transpose(1, 2)
+        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
+        return x
+
+class IQGM(nn.Module):
+    def __init__(self, in_features, n_classes):
+        super(IQGM, self).__init__()
+
+        self.in_features = in_features
+        self.n_classes = n_classes
+        self.fc = nn.Linear(self.in_features, self.n_classes)
+
+    def forward(self, feats):
+        c = F.softmax(self.fc(feats), dim=1)
+        _, m_indices = torch.sort(c, 0, descending=True)
+        m_feats = torch.index_select(feats, dim=0, index=m_indices[0,:]) #critical index?
+
+class MDMIL(nn.Module):
+    def __init__(self, n_classes):
+        super(MDMIL, self).__init__()
+        in_features = 1024
+        out_features = 512
+        self.pos_layer = PPEG(dim=out_features)
+        self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
+        # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
+        self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
+        self.n_classes = n_classes
+        self.layer1 = TransLayer(dim=out_features)
+        self.layer2 = TransLayer(dim=out_features)
+        self.norm = nn.LayerNorm(out_features)
+        self._fc2 = nn.Linear(out_features, self.n_classes)
+
+
+    def forward(self, x): #, **kwargs
+
+        h = x.float() #[B, n, 1024]
+        h = self._fc1(h) #[B, n, 512]
+        
+        # print('Feature Representation: ', h.shape)
+        #---->duplicate pad
+        H = h.shape[1]
+        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
+        add_length = _H * _W - H
+        h = torch.cat([h, h[:,:add_length,:]],dim = 1) #[B, N, 512]
+        
+
+        #---->cls_token
+        B = h.shape[0]
+        cls_tokens = self.cls_token.expand(B, -1, -1).cuda()
+        h = torch.cat((cls_tokens, h), dim=1)
+
+
+        #---->Translayer x1
+        h, attn1 = self.layer1(h) #[B, N, 512]
+
+        # print('After first TransLayer: ', h.shape)
+
+        #---->PPEG
+        h = self.pos_layer(h, _H, _W) #[B, N, 512]
+        # print('After PPEG: ', h.shape)
+        
+        #---->Translayer x2
+        h, attn2 = self.layer2(h) #[B, N, 512]
+
+        # print('After second TransLayer: ', h.shape) #[1, 1025, 512] 1025 = cls_token + 1024
+        #---->cls_token
+        
+        h = self.norm(h)[:,0]
+
+        #---->predict
+        logits = self._fc2(h) #[B, n_classes]
+        return logits, attn2
+
+if __name__ == "__main__":
+    data = torch.randn((1, 6000, 1024)).cuda()
+    model = MDMIL(n_classes=2).cuda()
+    print(model.eval())
+    logits, attn = model(data)
+    cls_attention = attn[:,:, 0, :6000]
+    values, indices = torch.max(cls_attention, 1)
+    mean = values.mean()
+    zeros = torch.zeros(values.shape).cuda()
+    filtered = torch.where(values > mean, values, zeros)
+    
+    # filter = values > values.mean()
+    # filtered_values = values[filter]
+    # values = np.where(values>values.mean(), values, 0)
+
+    print(filtered.shape)
+
+
+    # values = [v if v > values.mean().item() else 0 for v in values]
+    # print(values)
+    # print(len(values))
+
+    # logits = results_dict['logits']
+    # Y_prob = results_dict['Y_prob']
+    # Y_hat = results_dict['Y_hat']
+    # print(F.sigmoid(logits))
diff --git a/code/models/ResNet.py b/code/models/ResNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8fe70648f55b58eacc5da0578821d7842c563f3
--- /dev/null
+++ b/code/models/ResNet.py
@@ -0,0 +1,397 @@
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None, momentum_bn=0.1):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width, momentum=momentum_bn)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width, momentum=momentum_bn)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion, momentum=momentum_bn)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class NormedLinear(nn.Module):
+
+    def __init__(self, in_features, out_features):
+        super(NormedLinear, self).__init__()
+        self.weight = Parameter(torch.Tensor(in_features, out_features))
+        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
+
+    def forward(self, x):
+        out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
+        return out
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, two_branch=False, mlp=False, normlinear=False,
+                 momentum_bn=0.1, attention=False, attention_layers=3, return_attn=False):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        self.return_attn = return_attn
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.two_branch = two_branch
+        self.momentum_bn = momentum_bn
+        self.mlp = mlp
+        linear = NormedLinear if normlinear else nn.Linear
+
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes, momentum=momentum_bn)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+
+        if attention:
+            self.att_branch = self._make_layer(block, 512, attention_layers, 1, attention=True)
+        else:
+            self.att_branch = None
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        if self.mlp:
+            if self.two_branch:
+                self.fc = nn.Sequential(
+                    nn.Linear(512 * block.expansion, 512 * block.expansion),
+                    nn.ReLU()
+                ) 
+                self.instDis = linear(512 * block.expansion, num_classes)
+                self.groupDis = linear(512 * block.expansion, num_classes)
+            else:
+                self.fc = nn.Sequential(
+                    nn.Linear(512 * block.expansion, 512 * block.expansion),
+                    nn.ReLU(),
+                    linear(512 * block.expansion, num_classes)
+                ) 
+        else:
+            self.fc = nn.Linear(512 * block.expansion, num_classes)
+            if self.two_branch:
+                self.groupDis = nn.Linear(512 * block.expansion, num_classes)
+
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, attention=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion, momentum=self.momentum_bn),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer, momentum_bn=self.momentum_bn))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer, momentum_bn=self.momentum_bn))
+
+        if attention:
+            layers.append(nn.Sequential(
+                conv1x1(self.inplanes, 128),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                conv1x1(128, 1),
+                nn.BatchNorm2d(1),
+                nn.Sigmoid()
+            ))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        if self.att_branch is not None:
+            att_map = self.att_branch(x)
+            x = x + att_map * x
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        if self.mlp and self.two_branch:
+            x = self.fc(x)
+            x1 = self.instDis(x)
+            x2 = self.groupDis(x)
+            return [x1, x2]
+        else:
+            x1 = self.fc(x)
+            if self.two_branch:
+                x2 = self.groupDis(x)
+                return [x1, x2]
+            return x1
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-101 32x8d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-50-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-101-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
\ No newline at end of file
diff --git a/code/models/TransMIL.py b/code/models/TransMIL.py
index c78599d34d7faf2cc54fa182f5fb11ef63b13ea0..ddc126286835639ded83fd68054401c500237061 100755
--- a/code/models/TransMIL.py
+++ b/code/models/TransMIL.py
@@ -3,6 +3,17 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 from nystrom_attention import NystromAttention
+import models.ResNet as ResNet
+from pathlib import Path
+
+try:
+    import apex
+    apex_available=True
+except ModuleNotFoundError:
+    # Error handling
+    apex_available = False
+    pass
+
 
 
 class TransLayer(nn.Module):
@@ -38,9 +49,9 @@ class PPEG(nn.Module):
     def forward(self, x, H, W):
         B, _, C = x.shape
         cls_token, feat_token = x[:, 0], x[:, 1:]
-        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
+        cnn_feat = feat_token.transpose(1, 2).contiguous().view(B, C, H, W)
         x = self.proj(cnn_feat)+cnn_feat+self.proj1(cnn_feat)+self.proj2(cnn_feat)
-        x = x.flatten(2).transpose(1, 2)
+        x = x.flatten(2).transpose(1, 2).contiguous()
         x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
         return x
 
@@ -48,21 +59,37 @@ class PPEG(nn.Module):
 class TransMIL(nn.Module):
     def __init__(self, n_classes):
         super(TransMIL, self).__init__()
-        in_features = 512
-        out_features=512
+        in_features = 1024
+        out_features = 512
         self.pos_layer = PPEG(dim=out_features)
         self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
         # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
+        if apex_available: 
+            norm_layer = apex.normalization.FusedLayerNorm
+            
+        else:
+            norm_layer = nn.LayerNorm
         self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
         self.n_classes = n_classes
-        self.layer1 = TransLayer(dim=out_features)
-        self.layer2 = TransLayer(dim=out_features)
-        self.norm = nn.LayerNorm(out_features)
+        self.layer1 = TransLayer(norm_layer=norm_layer, dim=out_features)
+        self.layer2 = TransLayer(norm_layer=norm_layer, dim=out_features)
+        # self.norm = nn.LayerNorm(out_features)
+        self.norm = norm_layer(out_features)
         self._fc2 = nn.Linear(out_features, self.n_classes)
 
+        # self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True).to(self.device)
+        # home = Path.cwd().parts[1]
+        # # self.model_ft.fc = nn.Identity()
+        # # self.model_ft.load_from_checkpoint(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth', strict=False)
+        # self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+        # for param in self.model_ft.parameters():
+        #     param.requires_grad = False
+        # self.model_ft.fc = nn.Linear(2048, self.in_features)
+
 
     def forward(self, x): #, **kwargs
 
+        # x = self.model_ft(x).unsqueeze(0)
         h = x.float() #[B, n, 1024]
         h = self._fc1(h) #[B, n, 512]
         
@@ -71,6 +98,8 @@ class TransMIL(nn.Module):
         H = h.shape[1]
         _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
         add_length = _H * _W - H
+
+        # print(h.shape)
         h = torch.cat([h, h[:,:add_length,:]],dim = 1) #[B, N, 512]
         
 
@@ -102,7 +131,7 @@ class TransMIL(nn.Module):
         return logits, attn2
 
 if __name__ == "__main__":
-    data = torch.randn((1, 6000, 512)).cuda()
+    data = torch.randn((1, 6000, 1024)).cuda()
     model = TransMIL(n_classes=2).cuda()
     print(model.eval())
     logits, attn = model(data)
diff --git a/code/models/TransformerMIL.py b/code/models/TransformerMIL.py
index ea249ba371fcfa7113625f6459ceb73421c13182..d5e6b89fa4088afc1b143362b627f33de7d57ce6 100644
--- a/code/models/TransformerMIL.py
+++ b/code/models/TransformerMIL.py
@@ -17,7 +17,7 @@ class TransLayer(nn.Module):
             num_landmarks = dim//2,    # number of landmarks
             pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
             residual = True,         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
-            dropout=0.25 #0.1
+            dropout=0.7 #0.1
         )
 
     def forward(self, x):
@@ -48,15 +48,15 @@ class PPEG(nn.Module):
 class TransformerMIL(nn.Module):
     def __init__(self, n_classes):
         super(TransformerMIL, self).__init__()
-        in_features = 512
+        in_features = 1024
         out_features = 512
-        self.pos_layer = PPEG(dim=out_features)
+        # self.pos_layer = PPEG(dim=out_features)
         self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
         # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
         self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
         self.n_classes = n_classes
         self.layer1 = TransLayer(dim=out_features)
-        self.layer2 = TransLayer(dim=out_features)
+        # self.layer2 = TransLayer(dim=out_features)
         self.norm = nn.LayerNorm(out_features)
         self._fc2 = nn.Linear(out_features, self.n_classes)
 
@@ -83,6 +83,8 @@ class TransformerMIL(nn.Module):
         #---->Translayer x1
         h, attn1 = self.layer1(h) #[B, N, 512]
 
+        
+
         # print('After first TransLayer: ', h.shape)
 
         #---->PPEG
@@ -99,6 +101,7 @@ class TransformerMIL(nn.Module):
 
         #---->predict
         logits = self._fc2(h) #[B, n_classes]
+        # return logits, attn2
         return logits, attn1
 
 if __name__ == "__main__":
diff --git a/code/models/__init__.py b/code/models/__init__.py
index 73aad9d74d278565976a1dd2c63c69dc7a0997ad..795017ec760368727809f128f00815f1b691ea00 100755
--- a/code/models/__init__.py
+++ b/code/models/__init__.py
@@ -1 +1,3 @@
 from .model_interface import ModelInterface
+# import ResNet as ResNet
+from .TransMIL import TransMIL
diff --git a/code/models/__pycache__/AttMIL.cpython-39.pyc b/code/models/__pycache__/AttMIL.cpython-39.pyc
index f0f6f23506b721b78a8172b4ed1b215c65ae45b0..550f6876773393094b93f1d14f352eb024d7917a 100644
Binary files a/code/models/__pycache__/AttMIL.cpython-39.pyc and b/code/models/__pycache__/AttMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/Chowder.cpython-39.pyc b/code/models/__pycache__/Chowder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ec68d7a365117ff00442fa3ee31e78c94e6bf0b
Binary files /dev/null and b/code/models/__pycache__/Chowder.cpython-39.pyc differ
diff --git a/code/models/__pycache__/ResNet.cpython-39.pyc b/code/models/__pycache__/ResNet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..453943eef8a113608c389d6590ffeb844bbe73aa
Binary files /dev/null and b/code/models/__pycache__/ResNet.cpython-39.pyc differ
diff --git a/code/models/__pycache__/TransMIL.cpython-39.pyc b/code/models/__pycache__/TransMIL.cpython-39.pyc
index 38ed5a1f3fedfb29af2fc50c7909b312399d39d3..547c0a2e1594a0ef49a2c0482c066af1ca368fc9 100644
Binary files a/code/models/__pycache__/TransMIL.cpython-39.pyc and b/code/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/TransformerMIL.cpython-39.pyc b/code/models/__pycache__/TransformerMIL.cpython-39.pyc
index 999c95ee4caea337f6a77e6cc0eaad0fb3a6c90c..7ed17970f205f815c2b3f9dc1b6529a296ec9548 100644
Binary files a/code/models/__pycache__/TransformerMIL.cpython-39.pyc and b/code/models/__pycache__/TransformerMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/__init__.cpython-39.pyc b/code/models/__pycache__/__init__.cpython-39.pyc
index 45428bd7fc4c40dc9fb1116df4a96aea2fb568aa..c7462b47578213c58c1f1bce6f2a74bbd2bceeca 100644
Binary files a/code/models/__pycache__/__init__.cpython-39.pyc and b/code/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/models/__pycache__/model_interface.cpython-39.pyc b/code/models/__pycache__/model_interface.cpython-39.pyc
index 565cfe10390ef1f7b43c48995fdfeb7ea63773a2..176ab832bfacbc7d53356fade7a947b9b78e4a7e 100644
Binary files a/code/models/__pycache__/model_interface.cpython-39.pyc and b/code/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/code/models/ckpt/__init__.py b/code/models/ckpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/code/models/ckpt/retccl_best_ckpt.pth b/code/models/ckpt/retccl_best_ckpt.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f3b5568d87aa87cf21771297e4d0bff4ebafd892
Binary files /dev/null and b/code/models/ckpt/retccl_best_ckpt.pth differ
diff --git a/code/models/model_interface.py b/code/models/model_interface.py
index 586d80c41c22c5e5f3b1f48a95e130fdb0911571..9f4f282f67a84c5482a3f30fcb40b6d95bdcdd7d 100755
--- a/code/models/model_interface.py
+++ b/code/models/model_interface.py
@@ -15,9 +15,12 @@ from PIL import Image
 from MyOptimizer import create_optimizer
 from MyLoss import create_loss
 from utils.utils import cross_entropy_torch
+from utils.custom_resnet50 import resnet50_baseline
+
 from timm.loss import AsymmetricLossSingleLabel
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
-
+from libauc.losses import AUCMLoss, AUCM_MultiLabel, CompositionalAUCLoss
+from libauc.optimizers import PESG, PDSCA
 #---->
 import torch
 import torch.nn as nn
@@ -25,6 +28,7 @@ import torch.nn.functional as F
 import torchmetrics
 from torchmetrics.functional import stat_scores
 from torch import optim as optim
+
 # from sklearn.metrics import roc_curve, auc, roc_curve_score
 
 
@@ -41,6 +45,22 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 
 from captum.attr import LayerGradCam
+import models.ResNet as ResNet
+
+class FeatureExtractor(pl.LightningDataModule):
+    def __init__(self, model_name, n_classes):
+        self.n_classes = n_classes
+        
+        self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True)
+        home = Path.cwd().parts[1]
+        self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+        # self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+        for param in self.model_ft.parameters():
+            param.requires_grad = False
+        self.model_ft.fc = nn.Linear(2048, self.out_features)
+
+    def forward(self,x):
+        return self.model_ft(x)
 
 class ModelInterface(pl.LightningModule):
 
@@ -48,17 +68,20 @@ class ModelInterface(pl.LightningModule):
     def __init__(self, model, loss, optimizer, **kargs):
         super(ModelInterface, self).__init__()
         self.save_hyperparameters()
+        self.n_classes = model.n_classes
         self.load_model()
-        self.loss = create_loss(loss)
+        self.loss = create_loss(loss, model.n_classes)
+        # self.loss = AUCM_MultiLabel(num_classes = model.n_classes, device=self.device)
         # self.asl = AsymmetricLossSingleLabel()
         # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
         # self.loss = 
         # print(self.model)
         self.model_name = model.name
         
+        
         # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
         self.optimizer = optimizer
-        self.n_classes = model.n_classes
+        
         self.save_path = kargs['log']
         if Path(self.save_path).parts[3] == 'tcmr':
             temp = list(Path(self.save_path).parts)
@@ -66,15 +89,20 @@ class ModelInterface(pl.LightningModule):
             temp[3] = 'tcmr_viral'
             self.save_path = '/'.join(temp)
 
+        # if kargs['task']:
+        #     self.task = kargs['task']
+        self.task = Path(self.save_path).parts[3]
+
+
         #---->acc
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
         # print(self.experiment)
         #---->Metrics
         if self.n_classes > 2: 
-            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes)
             
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
-                                                                           average='micro'),
+                                                                           average='weighted'),
                                                      torchmetrics.CohenKappa(num_classes = self.n_classes),
                                                      torchmetrics.F1Score(num_classes = self.n_classes,
                                                                      average = 'macro'),
@@ -86,10 +114,11 @@ class ModelInterface(pl.LightningModule):
                                                                             num_classes = self.n_classes)])
                                                                             
         else : 
-            self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average = 'weighted')
+            self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average='weighted')
+            # self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average = 'weighted')
 
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
-                                                                           average = 'micro'),
+                                                                           average = 'weighted'),
                                                      torchmetrics.CohenKappa(num_classes = 2),
                                                      torchmetrics.F1Score(num_classes = 2,
                                                                      average = 'macro'),
@@ -109,11 +138,14 @@ class ModelInterface(pl.LightningModule):
         self.count = 0
         self.backbone = kargs['backbone']
 
-        self.out_features = 512
-        if kargs['backbone'] == 'dino':
+        self.out_features = 1024
+
+        if self.backbone == 'features':
+            self.model_ft = None
+        elif self.backbone == 'dino':
             self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
             self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
-        elif kargs['backbone'] == 'resnet18':
+        elif self.backbone == 'resnet18':
             self.model_ft = models.resnet18(pretrained=True)
             # modules = list(resnet18.children())[:-1]
             for param in self.model_ft.parameters():
@@ -132,12 +164,33 @@ class ModelInterface(pl.LightningModule):
             #     nn.Linear(512, self.out_features),
             #     nn.GELU(),
             # )
-        elif kargs['backbone'] == 'resnet50':
-
-            self.model_ft = models.resnet50(pretrained=True)    
+        elif self.backbone == 'retccl':
+            # import models.ResNet as ResNet
+            self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True)
+            home = Path.cwd().parts[1]
+            # pre_model = 
+            # self.model_ft.fc = nn.Identity()
+            # self.model_ft.load_from_checkpoint(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth', strict=False)
+            self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
             for param in self.model_ft.parameters():
                 param.requires_grad = False
             self.model_ft.fc = nn.Linear(2048, self.out_features)
+            
+            # self.model_ft = FeatureExtractor('retccl', self.n_classes)
+
+
+        elif self.backbone == 'resnet50':
+            
+            self.model_ft = resnet50_baseline(pretrained=True)
+            for param in self.model_ft.parameters():
+                param.requires_grad = False
+
+            # self.model_ft = models.resnet50(pretrained=True)
+            # for param in self.model_ft.parameters():
+            #     param.requires_grad = False
+            # self.model_ft.fc = nn.Linear(2048, self.out_features)
+
+
             # modules = list(resnet50.children())[:-3]
             # res50 = nn.Sequential(
             #     *modules,     
@@ -150,7 +203,9 @@ class ModelInterface(pl.LightningModule):
             #     nn.Linear(1024, self.out_features),
             #     # nn.GELU()
             # )
-        elif kargs['backbone'] == 'efficientnet':
+        # elif kargs
+            
+        elif self.backbone == 'efficientnet':
             efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
             for param in efficientnet.parameters():
                 param.requires_grad = False
@@ -160,7 +215,7 @@ class ModelInterface(pl.LightningModule):
                 efficientnet,
                 nn.GELU(),
             )
-        elif kargs['backbone'] == 'simple': #mil-ab attention
+        elif self.backbone == 'simple': #mil-ab attention
             feature_extracting = False
             self.model_ft = nn.Sequential(
                 nn.Conv2d(3, 20, kernel_size=5),
@@ -176,15 +231,23 @@ class ModelInterface(pl.LightningModule):
         # print(self.model_ft[0].features[-1])
         # print(self.model_ft)
 
+    # def __build_
+
     def forward(self, x):
         # print(x.shape)
-        feats = self.model_ft(x).unsqueeze(0)
+        if self.model_ft:
+            feats = self.model_ft(x).unsqueeze(0)
+        else: 
+            feats = x.unsqueeze(0)
         return self.model(feats)
+        # return self.model(x)
 
     def step(self, input):
 
         input = input.squeeze(0).float()
-        logits, _ = self(input) 
+        logits, _ = self(input.contiguous()) 
+
+        
 
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim=1)
@@ -196,7 +259,8 @@ class ModelInterface(pl.LightningModule):
         input, label, _= batch
 
         #random image dropout
-        # bag_size = 500
+
+        # bag_size = input.squeeze().shape[0] * 0.7
         # bag_idxs = torch.randperm(input.squeeze(0).shape[0])[:bag_size]
         # input = input.squeeze(0)[bag_idxs].unsqueeze(0)
 
@@ -218,20 +282,20 @@ class ModelInterface(pl.LightningModule):
             # Y = int(label[0])
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (int(Y_hat) == Y)
-        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1)
+        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
 
-        if self.current_epoch % 10 == 0:
+        # if self.current_epoch % 10 == 0:
 
-            # images = input.squeeze()[:10, :, :, :]
-            # for i in range(10):
-            img = input.squeeze(0)[:10, :, :, :]
-            img = (img - torch.min(img)/(torch.max(img)-torch.min(img)))*255.0
+        #     # images = input.squeeze()[:10, :, :, :]
+        #     # for i in range(10):
+        #     img = input.squeeze(0)[:10, :, :, :]
+        #     img = (img - torch.min(img)/(torch.max(img)-torch.min(img)))*255.0
             
-            # mg = img.cpu().numpy()
-            grid = torchvision.utils.make_grid(img, normalize=True, value_range=(0, 255), scale_each=True)
-            # grid = img.detach().cpu().numpy()
-        # log input images 
-            self.loggers[0].experiment.add_image(f'{self.current_epoch}/input', grid)
+        #     # mg = img.cpu().numpy()
+        #     grid = torchvision.utils.make_grid(img, normalize=True, value_range=(0, 255), scale_each=True)
+        #     # grid = img.detach().cpu().numpy()
+        # # log input images 
+        #     self.loggers[0].experiment.add_image(f'{self.current_epoch}/input', grid)
 
 
         return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': Y} 
@@ -243,14 +307,16 @@ class ModelInterface(pl.LightningModule):
         # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
         target = torch.stack([x['label'] for x in training_step_outputs])
         # target = torch.argmax(target, dim=1)
-        for c in range(self.n_classes):
-            count = self.data[c]["count"]
-            correct = self.data[c]["correct"]
-            if count == 0: 
-                acc = None
-            else:
-                acc = float(correct) / count
-            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+
+        if self.current_epoch % 5 == 0:
+            for c in range(self.n_classes):
+                count = self.data[c]["count"]
+                correct = self.data[c]["correct"]
+                if count == 0: 
+                    acc = None
+                else:
+                    acc = float(correct) / count
+                print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
         # print('max_probs: ', max_probs)
@@ -258,7 +324,7 @@ class ModelInterface(pl.LightningModule):
         if self.current_epoch % 10 == 0:
             self.log_confusion_matrix(max_probs, target, stage='train')
 
-        self.log('Train/auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True)
+        self.log('Train/auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
     def validation_step(self, batch, batch_idx):
 
@@ -271,8 +337,10 @@ class ModelInterface(pl.LightningModule):
         # Y = int(label[0][1])
         Y = torch.argmax(label)
 
+        # print(Y_hat)
         self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (Y_hat.item() == Y)
+        self.data[Y]["correct"] += (int(Y_hat) == Y)
+        # self.data[Y]["correct"] += (Y_hat.item() == Y)
 
         return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
 
@@ -284,20 +352,18 @@ class ModelInterface(pl.LightningModule):
         target = torch.stack([x['label'] for x in val_step_outputs])
         
         self.log_dict(self.valid_metrics(logits, target),
-                          on_epoch = True, logger = True)
+                          on_epoch = True, logger = True, sync_dist=True)
         
         #---->
         # logits = logits.long()
         # target = target.squeeze().long()
         # logits = logits.squeeze(0)
         if len(target.unique()) != 1:
-            self.log('val_auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True)
+            self.log('val_auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         else:    
-            self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True)
+            self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
-        
-
-        self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
+        self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         
 
         precision, recall, thresholds = self.PRC(probs, target)
@@ -498,6 +564,8 @@ class ModelInterface(pl.LightningModule):
             print(f'{keys} = {values}')
             metrics[keys] = values.cpu().numpy()
         #---->acc log
+
+
         for c in range(self.n_classes):
             count = self.data[c]["count"]
             correct = self.data[c]["correct"]
@@ -530,8 +598,13 @@ class ModelInterface(pl.LightningModule):
     def configure_optimizers(self):
         # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
         optimizer = create_optimizer(self.optimizer, self.model)
+        # optimizer = PESG(self.model, a=self.loss.a, b=self.loss.b, loss_fn=self.loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
+        # optimizer = PDSCA(self.model, loss_fn=self.loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
         return optimizer     
 
+    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
+        optimizer.zero_grad(set_to_none=True)
+
     def reshape_transform(self, tensor):
         # print(tensor.shape)
         H = tensor.shape[1]
@@ -545,6 +618,7 @@ class ModelInterface(pl.LightningModule):
 
     def load_model(self):
         name = self.hparams.model.name
+        backbone = self.hparams.model.backbone
         # Change the `trans_unet.py` file name to `TransUnet` class name.
         # Please always name your model file name as `trans_unet.py` and
         # class name or funciton name corresponding `TransUnet`.
@@ -559,8 +633,26 @@ class ModelInterface(pl.LightningModule):
         except:
             raise ValueError('Invalid Module File Name or Invalid Class Name!')
         self.model = self.instancialize(Model)
+
+        # if backbone == 'retccl':
+
+        #     self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True)
+        #     home = Path.cwd().parts[1]
+        #     # self.model_ft.fc = nn.Identity()
+        #     # self.model_ft.load_from_checkpoint(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth', strict=False)
+        #     self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+        #     for param in self.model_ft.parameters():
+        #         param.requires_grad = False
+        #     self.model_ft.fc = nn.Linear(2048, self.out_features)
+        
+        # elif backbone == 'resnet50':
+        #     self.model_ft = resnet50_baseline(pretrained=True)
+        #     for param in self.model_ft.parameters():
+        #         param.requires_grad = False
+
         pass
 
+
     def instancialize(self, Model, **other_args):
         """ Instancialize a model using the corresponding parameters
             from self.hparams dictionary. You can also input any args
@@ -573,6 +665,8 @@ class ModelInterface(pl.LightningModule):
             if arg in inkeys:
                 args1[arg] = getattr(self.hparams.model, arg)
         args1.update(other_args)
+
+
         return Model(**args1)
 
     def log_image(self, tensor, stage, name):
@@ -594,6 +688,8 @@ class ModelInterface(pl.LightningModule):
         else:
             fig_.savefig(f'{self.loggers[0].log_dir}/cm_test.png', dpi=400)
 
+        fig_.clf()
+
     def log_roc_curve(self, probs, target, stage):
 
         fpr_list, tpr_list, thresholds = self.ROC(probs, target)
diff --git a/code/test.ipynb b/code/test.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..4149f989b13e724c9ac6f209b4b3ecc887f80b97
--- /dev/null
+++ b/code/test.ipynb
@@ -0,0 +1,33 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "a = [None] * 500\n",
+    "print(a.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "b = np.fromfile(a)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
\ No newline at end of file
diff --git a/code/test_visualize.py b/code/test_visualize.py
index 82f6cb4c6dc9d14404aece06d93c729206b70ad6..5e4a5e76491d1725a78c90216f077c4eac32bd17 100644
--- a/code/test_visualize.py
+++ b/code/test_visualize.py
@@ -17,8 +17,9 @@ from utils.utils import *
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
 import torch
+import torch.nn as nn
 
-from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam import GradCAM, EigenGradCAM, EigenCAM
 from pytorch_grad_cam.utils.image import show_cam_on_image
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 
@@ -26,28 +27,35 @@ import cv2
 from PIL import Image
 from matplotlib import pyplot as plt
 import pandas as pd
+import json
+import pprint
+
 
 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='test', type=str)
-    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    parser.add_argument('--config', default='../DeepGraft/TransMIL.yaml',type=str)
     parser.add_argument('--version', default=0,type=int)
     parser.add_argument('--epoch', default='0',type=str)
     parser.add_argument('--gpus', default = 2, type=int)
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
-    parser.add_argument('--bag_size', default = 1024, type=int)
+    parser.add_argument('--bag_size', default = 10000, type=int)
 
     args = parser.parse_args()
     return args
 
 class custom_test_module(ModelInterface):
 
+    # self.task = kargs['task']    
+    # self.task = 'tcmr_viral'
+
     def test_step(self, batch, batch_idx):
 
         torch.set_grad_enabled(True)
-        input_data, label, (wsi_name, batch_names) = batch
+        input_data, label, (wsi_name, batch_names, patient) = batch
+        patient = patient[0]
         wsi_name = wsi_name[0]
         label = label.float()
         # logits, Y_prob, Y_hat = self.step(data) 
@@ -60,6 +68,10 @@ class custom_test_module(ModelInterface):
         Y = torch.argmax(label)
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim=1)
+
+        # print('Y_hat:', Y_hat)
+        # print('Y_prob:', Y_prob)
+
         
         #----> Get GradCam maps, map each instance to attention value, assemble, overlay on original WSI 
         if self.model_name == 'TransMIL':
@@ -67,6 +79,10 @@ class custom_test_module(ModelInterface):
             # target_layers = [self.model_ft[0].features[-1]] # 32x32
             self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
             # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+        elif self.model_name == 'TransformerMIL':
+            target_layers = [self.model.layer1.norm]
+            self.cam = EigenCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
+            # self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
         else:
             target_layers = [self.model.attention_weights]
             self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
@@ -74,14 +90,15 @@ class custom_test_module(ModelInterface):
         data_ft = self.model_ft(input_data).unsqueeze(0).float()
         instance_count = input_data.size(0)
         target = [ClassifierOutputTarget(Y)]
-        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
+        grayscale_cam = self.cam(input_tensor=data_ft, targets=target, eigen_smooth=True)
         grayscale_cam = torch.Tensor(grayscale_cam)[:instance_count, :] #.to(self.device)
 
         #----------------------------------------------------
         # Get Topk Tiles and Topk Patients
         #----------------------------------------------------
+        k = 10
         summed = torch.mean(grayscale_cam, dim=2)
-        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
+        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), k, dim=0)
         topk_data = input_data[topk_indices].detach()
         
         #----------------------------------------------------
@@ -95,56 +112,158 @@ class custom_test_module(ModelInterface):
         # Tile Level Attention Maps
         #----------------------------------------------------
 
-        self.save_attention_map(wsi_name, input_data, batch_names, grayscale_cam, target=Y)
+        # print(input_data.shape)
+        # print(len(batch_names))
+        # if visualize:
+        #     self.save_attention_map(wsi_name, input_data, batch_names, grayscale_cam, target=Y)
+        # print('test_step_patient: ', patient)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'topk_data': topk_data} #
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'patient': patient, 'topk_data': topk_data} #
         # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data
 
     def test_epoch_end(self, output_results):
 
+        k_patient = 1
+        k_slide = 1
+
+        pp = pprint.PrettyPrinter(indent=4)
+
+
         logits = torch.cat([x['logits'] for x in output_results], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in output_results])
         max_probs = torch.stack([x['Y_hat'] for x in output_results])
         # target = torch.stack([x['label'] for x in output_results], dim = 0)
         target = torch.stack([x['label'] for x in output_results])
         # target = torch.argmax(target, dim=1)
-        patients = [x['name'] for x in output_results]
+        slide = [x['name'] for x in output_results]
+        patients = [x['patient'] for x in output_results]
         topk_tiles = [x['topk_data'] for x in output_results]
         #---->
 
-        auc = self.AUROC(probs, target)
+        if len(target.unique()) !=1:
+            auc = self.AUROC(probs, target)
+        else: auc = torch.tensor(0)
         metrics = self.test_metrics(logits , target)
 
 
         # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
         metrics['test_auc'] = auc
 
-        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+        # print(metrics)
+        np_metrics = {k: metrics[k].item() for k in metrics.keys()}
+        # print(np_metrics)
 
-        #---->get highest scoring patients for each class
-        # test_path = Path(self.save_path) / 'most_predictive' 
         
-        # Path.mkdir(output_path, exist_ok=True)
-        topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
+        complete_patient_dict = {}
+        '''
+        Patient
+        -> slides:
+            
+            -> SlideName:
+                ->probs = [0.5, 0.5] 
+                ->topk = [10,3,224,224]
+        -> score = []
+        '''
+
+
+        for p, s, l, topkt in zip(patients, slide, probs, topk_tiles):
+            if p not in complete_patient_dict.keys():
+                complete_patient_dict[p] = {'slides':{}}
+            complete_patient_dict[p]['slides'][s] = {'probs': l, 'topk':topkt}
+
+        patient_list = []            
+        patient_score = []            
+        for p in complete_patient_dict.keys():
+            score = []
+            
+            for s in complete_patient_dict[p]['slides'].keys():
+                score.append(complete_patient_dict[p]['slides'][s]['probs'])
+            score = torch.mean(torch.stack(score), dim=0) #.cpu().detach().numpy()
+            complete_patient_dict[p]['score'] = score
+            print(p, score)
+            patient_list.append(p)    
+            patient_score.append(score)    
+
+        print(patient_list)
+        #topk patients: 
+
+
+        # task = 'tcmr_viral'
+        task = Path(self.save_path).parts[-5]
+        label_map_path = 'label_map.json'
+        with open(label_map_path, 'r') as f:
+            label_map = json.load(f)
+        
+        # topk_patients, topk_p_indices = torch.topk(score, 5, dim=0)
+
+        # print(probs.squeeze(0))
+        # topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0) # topk = 
+        # print(topk)
+        
+        # topk_indices = topk_indices.transpose(0, 1)
+
+        output_dict = {}
+    
+
         for n in range(self.n_classes):
-            print('class: ', n)
+
+            class_name = f'{n}_{label_map[task][str(n)]}'
+
+            output_dict[class_name] = {}
+            # class_name = str(n)
+            print('class: ', class_name)
+            print(score)
+            _, topk_indices = torch.topk(score, k_patient, dim=0) # change to 3
+            print(topk_indices)
+
+            topk_patients = [patient_list[i] for i in topk_indices]
+
+            patient_top_slides = {} 
+            for p in topk_patients:
+                print(p)
+                output_dict[class_name][p] = {}
+                output_dict[class_name][p]['Patient_Score'] = complete_patient_dict[p]['score'].cpu().detach().numpy().tolist()
+
+                slides = list(complete_patient_dict[p]['slides'].keys())
+                slide_scores = [complete_patient_dict[p]['slides'][s]['probs'] for s in slides]
+                slide_scores = torch.stack(slide_scores)
+                print(slide_scores)
+                _, topk_slide_indices = torch.topk(slide_scores, k_slide, dim=0)
+                # topk_slide_indices = topk_slide_indices.squeeze(0)
+                print(topk_slide_indices[0])
+                topk_patient_slides = [slides[i] for i in topk_slide_indices[0]]
+                patient_top_slides[p] = topk_patient_slides
+
+                output_dict[class_name][p]['Top_Slides'] = [{slides[i]: {'Slide_Score': slide_scores[i].cpu().detach().numpy().tolist()}} for i in topk_slide_indices[0]]
             
-            topk_patients = [patients[i[n]] for i in topk_indices]
-            topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
-            for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
-                print(p, x[n])
-                patient = p
-                # outpath = test_path / str(n) / patient 
-                outpath = Path(self.save_path) / str(n) / patient
-                outpath.mkdir(parents=True, exist_ok=True)
-                for i in range(len(t)):
-                    tile = t[i]
-                    tile = tile.cpu().numpy().transpose(1,2,0)
-                    tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
-                    tile = tile.astype(np.uint8)
-                    img = Image.fromarray(tile)
+
+            
+
+            for p in topk_patients: 
+
+                score = complete_patient_dict[p]['score']
+                print(p, score)
+                print('Topk Slides:')
+                for slide in patient_top_slides[p]:
+                    print(slide)
+                    outpath = Path(self.save_path) / class_name / p / slide
+                    outpath.mkdir(parents=True, exist_ok=True)
+                
+                    topk_tiles = complete_patient_dict[p]['slides'][slide]['topk']
+                    for i in range(topk_tiles.shape[0]):
+                        tile = topk_tiles[i]
+                        tile = tile.cpu().numpy().transpose(1,2,0)
+                        tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
+                        tile = tile.astype(np.uint8)
+                        img = Image.fromarray(tile)
                     
                     img.save(f'{outpath}/{i}.jpg')
+        output_dict['Test_Metrics'] = np_metrics
+        pp.pprint(output_dict)
+        json.dump(output_dict, open(f'{self.save_path}/test_metrics.json', 'w'))
+
+
+        
 
         for keys, values in metrics.items():
             print(f'{keys} = {values}')
@@ -174,17 +293,21 @@ class custom_test_module(ModelInterface):
             
             for tile_name in batch_names: 
                 pos = re.findall(r'\((.*?)\)', tile_name[0])
-                x, y = pos[0].split('_')
+                x, y = pos[-1].split('_')
                 coords.append((int(x),int(y)))
             return coords
         
         coords = get_coords(batch_names)
+        # coords_set = set(coords)
+
+        # print(coords)
         # temp_data = data.cpu()
         # print(data.shape)
         wsi = self.assemble(data, coords).cpu().numpy()
         # wsi = (wsi-wsi.min())/(wsi.max()-wsi.min())
         # wsi = wsi
-
+        # print(coords)
+        print('wsi.shape: ', wsi.shape)
         #--> Get interpolated mask from GradCam
         W, H = wsi.shape[0], wsi.shape[1]
         
@@ -192,10 +315,11 @@ class custom_test_module(ModelInterface):
         attention_map = grayscale_cam[:, :, 1].squeeze()
         attention_map = F.relu(attention_map)
         # print(attention_map)
-        input_h = 256
+        input_h = 224
         
         mask = torch.ones(( int(W/input_h), int(H/input_h))).to(self.device)
-
+        print('mask.shape: ', mask.shape)
+        print('attention_map.shape: ', attention_map.shape)
         for i, (x,y) in enumerate(coords):
             mask[y][x] = attention_map[i]
         mask = mask.unsqueeze(0).unsqueeze(0)
@@ -217,8 +341,11 @@ class custom_test_module(ModelInterface):
         wsi_cam = show_cam_on_image(wsi, mask)
         wsi_cam = ((wsi_cam-wsi_cam.min())/(wsi_cam.max()-wsi_cam.min()) * 255.0).astype(np.uint8)
         
+        size = (20000, 20000)
+
         img = Image.fromarray(wsi_cam)
         img = img.convert('RGB')
+        img.thumbnail(size, Image.ANTIALIAS)
         output_path = self.save_path / str(target.item())
         output_path.mkdir(parents=True, exist_ok=True)
         img.save(f'{output_path}/{wsi_name}_gradcam.jpg')
@@ -226,9 +353,14 @@ class custom_test_module(ModelInterface):
         wsi = ((wsi-wsi.min())/(wsi.max()-wsi.min()) * 255.0).astype(np.uint8)
         img = Image.fromarray(wsi)
         img = img.convert('RGB')
+        img.thumbnail(size, Image.ANTIALIAS)
         output_path = self.save_path / str(target.item())
         output_path.mkdir(parents=True, exist_ok=True)
         img.save(f'{output_path}/{wsi_name}.jpg')
+        del wsi
+        del img
+        del wsi_cam
+        del mask
 
 
     def assemble(self, tiles, coords): # with coordinates (x-y)
@@ -245,72 +377,57 @@ class custom_test_module(ModelInterface):
         count = 0
         # max_x = max(coords, key = lambda t: t[0])[0]
         d = tiles[0,:,:,:].permute(1,2,0).shape
-        print(d)
+        # print(d)
         white_value = 0
         x_max = max([x[0] for x in coords])
         y_max = max([x[1] for x in coords])
 
         for i, (x,y) in enumerate(coords):
-
-            # name = n[0]
-            # image = tiles[i,:,:,:].permute(1,2,0)
-            
-            # d = image.shape
-            # print(image.min())
-            # print(image.max())
-            # if image.max() > white_value:
-            #     white_value = image.max()
-            # # print(image.shape)
-            
-            # tile_position = '-'.join(name.split('_')[-2:])
-            # x,y = getPosition(tile_position)
-            
-            # y_max = y if y > y_max else y_max
             if x not in position_dict.keys():
                 position_dict[x] = [(y, i)]
             else: position_dict[x].append((y, i))
-            # count += 1
-        print(position_dict.keys())
-        x_positions = sorted(position_dict.keys())
-        # print(x_max)
-        # complete_image = torch.zeros([x_max, y_max, 3])
-
+        # x_positions = sorted(position_dict.keys())
 
+        test_img_compl = torch.ones([(y_max+1)*224, (x_max+1)*224, 3]).to(self.device)
         for i in range(x_max+1):
-
-            # if i in position_dict.keys():
-            #     print(i)
-            column = [None]*(int(y_max+1))
-            # if len(d) == 3:
-            # empty_tile = torch.zeros(d).to(self.device)
-            # else:
-            # empty_tile = torch.ones(d)
-            empty_tile = torch.ones(d).to(self.device)
-            # print(i)
             if i in position_dict.keys():
-                # print(i)
                 for j in position_dict[i]:
-                    print(j)
                     sample_idx = j[1]
-                    print(sample_idx)
-                    # img = tiles[sample_idx, :, :, :].permute(1,2,0)
-                    column[int(j[0])] = tiles[sample_idx, :, :, :]
-            column = [empty_tile if i is None else i for i in column]
-            print(column)
-            # for c in column:
-            #     print(c.shape)
-            # column = torch.vstack(column)
-            # print(column)
-            column = torch.stack(column)
-            assembled.append((i, column))
+                    if tiles[sample_idx, :, :, :].shape != [3,224,224]:
+                        img = tiles[sample_idx, :, :, :].permute(1,2,0)
+                    else: 
+                        img = tiles[sample_idx, :, :, :]
+                    y_coord = int(j[0])
+                    x_coord = int(i)
+                    test_img_compl[y_coord*224:(y_coord+1)*224, x_coord*224:(x_coord+1)*224, :] = img
+
+
+
+        # for i in range(x_max+1):
+        #     column = [None]*(int(y_max+1))
+        #     empty_tile = torch.ones(d).to(self.device)
+        #     if i in position_dict.keys():
+        #         for j in position_dict[i]:
+        #             sample_idx = j[1]
+        #             if tiles[sample_idx, :, :, :].shape != [3,224,224]:
+        #                 img = tiles[sample_idx, :, :, :].permute(1,2,0)
+        #             else: 
+        #                 img = tiles[sample_idx, :, :, :]
+        #             column[int(j[0])] = img
+        #     column = [empty_tile if i is None else i for i in column]
+        #     column = torch.vstack(column)
+        #     assembled.append((i, column))
         
-        assembled = sorted(assembled, key=lambda x: x[0])
+        # assembled = sorted(assembled, key=lambda x: x[0])
 
-        stack = [i[1] for i in assembled]
-        # print(stack)
-        img_compl = torch.hstack(stack)
-        print(img_compl)
-        return img_compl
+        # stack = [i[1] for i in assembled]
+        # # print(stack)
+        # img_compl = torch.hstack(stack)
+        # print(img_compl.shape)
+        # print(test_img_compl)
+        # print(torch.nonzero(img_compl - test_img_compl))
+        # print(img_compl)
+        return test_img_compl.cpu().detach()
 
 
 #---->main
@@ -331,8 +448,11 @@ def main(cfg):
     # cfg.callbacks = load_callbacks(cfg, save_path)
 
     home = Path.cwd().parts[1]
-    cfg.Data.label_file = '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
-    cfg.Data.data_dir = '/home/ylan/data/DeepGraft/224_128um/'
+    # cfg.Data.label_file = '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_Utrecht.json'
+    # cfg.Data.label_file = '/homeStor1/ylan/DeepGraft/training_tables/split_debug.json'
+    # cfg.Data.label_file = '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    # cfg.Data.patient_slide = '/homeStor1/ylan/DeepGraft/training_tables/cohort_stain_dict.json'
+    # cfg.Data.data_dir = '/homeStor1/ylan/data/DeepGraft/224_128um_v2/'
     DataInterface_dict = {
                 'data_root': cfg.Data.data_dir,
                 'label_path': cfg.Data.label_file,
@@ -353,15 +473,19 @@ def main(cfg):
                             'data': cfg.Data,
                             'log': cfg.log_path,
                             'backbone': cfg.Model.backbone,
+                            'task': cfg.task,
                             }
     # model = ModelInterface(**ModelInterface_dict)
     model = custom_test_module(**ModelInterface_dict)
+    # model._fc1 = nn.Sequential(nn.Linear(512, 512), nn.GELU())
     # model.save_path = cfg.log_path
     #---->Instantiate Trainer
     
+    tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path)
+
     trainer = Trainer(
         num_sanity_val_steps=0, 
-        # logger=cfg.load_loggers,
+        logger=tb_logger,
         # callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
@@ -384,6 +508,7 @@ def main(cfg):
     # log_path = Path('lightning_logs/2/checkpoints')
     model_paths = list(log_path.glob('*.ckpt'))
 
+
     if cfg.epoch == 'last':
         model_paths = [str(model_path) for model_path in model_paths if f'last' in str(model_path)]
     else:
@@ -404,7 +529,27 @@ def main(cfg):
     
     # Top 5 scoring patches for patient
     # GradCam
def check_home(cfg):
    """Rewrite absolute paths in *cfg* so they live under the current home root.

    The YAML configs hard-code paths under one top-level directory (e.g.
    ``/home/...`` vs ``/homeStor1/...`` as seen in the commented alternatives
    above).  When running on a machine whose filesystem root component
    differs, each stored path is re-rooted onto the first component of the
    current working directory; paths that already match are left untouched.

    Args:
        cfg: config object exposing ``General.log_path``, ``Data.data_dir``
            and ``Data.label_file`` as absolute path strings.

    Returns:
        The same ``cfg`` object, mutated in place.
    """
    home = Path.cwd().parts[1]

    def _relocate(path_str):
        # Swap the first path component for the current home; keep the rest.
        parts = Path(path_str).parts
        if parts[1] == home:
            return path_str
        return '/' + str(Path(home).joinpath(*parts[2:]))

    cfg.General.log_path = _relocate(cfg.General.log_path)
    cfg.Data.data_dir = _relocate(cfg.Data.data_dir)
    cfg.Data.label_file = _relocate(cfg.Data.label_file)

    return cfg
 
 if __name__ == '__main__':
 
@@ -421,6 +566,8 @@ if __name__ == '__main__':
     cfg.version = args.version
     cfg.epoch = args.epoch
 
+    cfg = check_home(cfg)
+
     config_path = '/'.join(Path(cfg.config).parts[1:])
     log_path = Path(cfg.General.log_path) / str(Path(config_path).parent)
 
@@ -428,8 +575,11 @@ if __name__ == '__main__':
     log_name =  f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
     task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
     # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+    cfg.task = task
     cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' 
     
+    # cfg.model_path = cfg.log_patth / 'code' / 'models' / 
+    
     
 
     #---->main
diff --git a/code/train.py b/code/train.py
index ddd37e443ffee175cdb37b0e2d2e6728728dc5ba..e01bc52c17b0b8b0698124bb4ae0e61d087a585d 100644
--- a/code/train.py
+++ b/code/train.py
@@ -14,8 +14,48 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
+from pytorch_lightning.strategies import DDPStrategy
 import torch
 from train_loop import KFoldLoop
+from pytorch_lightning.plugins.training_type import DDPPlugin
+
+
# Apex (NVIDIA's mixed-precision / distributed toolkit) is optional: import
# it when available, otherwise fall back silently to native PyTorch paths.
try:
    import apex
    from apex.parallel import DistributedDataParallel
    print('Apex available.')
except ModuleNotFoundError:
    # Apex not installed; ApexDDPPlugin below simply won't be usable.
    pass
+
def unwrap_lightning_module(wrapped_model):
    """Strip DDP / Lightning wrapper layers and return the inner module.

    Both apex's DistributedDataParallel and Lightning's internal module
    wrappers keep the wrapped model on ``.module``; unwrapping repeats
    until no known wrapper type remains.
    """
    # Imported lazily so importing this file does not itself require apex.
    from apex.parallel import DistributedDataParallel
    from pytorch_lightning.overrides.base import (
        _LightningModuleWrapperBase,
        _LightningPrecisionModuleWrapperBase,
    )

    wrapper_types = (
        DistributedDataParallel,
        _LightningModuleWrapperBase,
        _LightningPrecisionModuleWrapperBase,
    )
    inner = wrapped_model
    while isinstance(inner, wrapper_types):
        inner = inner.module
    return inner
+
+
class ApexDDPPlugin(DDPPlugin):
    """Lightning DDP plugin that wraps the model with apex's DistributedDataParallel.

    NOTE(review): requires apex to be installed (see the guarded import at the
    top of this file) — confirm before enabling the commented plugin hookup in
    ``main``.
    """

    def _setup_model(self, model):
        # Swap torch's DDP for apex's implementation.  delay_allreduce=False
        # presumably all-reduces gradients as they become ready — confirm
        # against the apex.parallel documentation.
        from apex.parallel import DistributedDataParallel

        return DistributedDataParallel(model, delay_allreduce=False)

    @property
    def lightning_module(self):
        # Lightning expects the bare LightningModule, so peel off the apex
        # DDP wrapper(s) added in _setup_model.
        return unwrap_lightning_module(self._model)
+
+
 
 #--->Setting parameters
 def make_parse():
@@ -23,11 +63,12 @@ def make_parse():
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
     parser.add_argument('--version', default=2,type=int)
-    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--gpus', nargs='+', default = [2], type=int)
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
     parser.add_argument('--bag_size', default = 1024, type=int)
     parser.add_argument('--resume_training', action='store_true')
+    parser.add_argument('--label_file', type=str)
     # parser.add_argument('--ckpt_path', default = , type=str)
     
 
@@ -59,6 +100,9 @@ def main(cfg):
     #             'dataset_cfg': cfg.Data,}
     # dm = DataInterface(**DataInterface_dict)
     home = Path.cwd().parts[1]
+    if cfg.Model.backbone == 'features':
+        use_features = True
+    else: use_features = False
     DataInterface_dict = {
                 'data_root': cfg.Data.data_dir,
                 'label_path': cfg.Data.label_file,
@@ -66,6 +110,7 @@ def main(cfg):
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
                 'bag_size': cfg.Data.bag_size,
+                'use_features': use_features,
                 }
 
     if cfg.Data.cross_val:
@@ -80,6 +125,7 @@ def main(cfg):
                             'data': cfg.Data,
                             'log': cfg.log_path,
                             'backbone': cfg.Model.backbone,
+                            'task': cfg.task,
                             }
     if cfg.Model.name == 'DTFDMIL':
         model = ModelInterface_DTFD(**ModelInterface_dict)
@@ -87,33 +133,64 @@ def main(cfg):
         model = ModelInterface(**ModelInterface_dict)
     
     #---->Instantiate Trainer
-    trainer = Trainer(
-        # num_sanity_val_steps=0, 
-        logger=cfg.load_loggers,
-        callbacks=cfg.callbacks,
-        max_epochs= cfg.General.epochs,
-        min_epochs = 200,
-        gpus=cfg.General.gpus,
-        # gpus = [0,2],
-        # strategy='ddp',
-        amp_backend='native',
-        # amp_level=cfg.General.amp_level,  
-        precision=cfg.General.precision,  
-        accumulate_grad_batches=cfg.General.grad_acc,
-        gradient_clip_val=0.0,
-        # fast_dev_run = True,
-        # limit_train_batches=1,
-        
-        # deterministic=True,
-        check_val_every_n_epoch=5,
-    )
+    # plugins = []
+    # if apex: 
+    #     plugins.append(ApexDDPPlugin())
+
+    if len(cfg.General.gpus) > 1:
+        trainer = Trainer(
+            logger=cfg.load_loggers,
+            callbacks=cfg.callbacks,
+            max_epochs= cfg.General.epochs,
+            min_epochs = 100,
+            accelerator='gpu',
+            # plugins=plugins,
+            devices=cfg.General.gpus,
+            strategy=DDPStrategy(find_unused_parameters=False),
+            replace_sampler_ddp=False,
+            amp_backend='native',
+            precision=cfg.General.precision,  
+            # accumulate_grad_batches=cfg.General.grad_acc,
+            gradient_clip_val=0.0,
+            # fast_dev_run = True,
+            # limit_train_batches=1,
+            
+            # deterministic=True,
+            check_val_every_n_epoch=5,
+        )
+    else:
+        trainer = Trainer(
+            # num_sanity_val_steps=0, 
+            logger=cfg.load_loggers,
+            callbacks=cfg.callbacks,
+            max_epochs= cfg.General.epochs,
+            min_epochs = 100,
+
+            # gpus=cfg.General.gpus,
+            accelerator='gpu',
+            devices=cfg.General.gpus,
+            amp_backend='native',
+            # amp_level=cfg.General.amp_level,  
+            precision=cfg.General.precision,  
+            accumulate_grad_batches=cfg.General.grad_acc,
+            gradient_clip_val=0.0,
+            # fast_dev_run = True,
+            # limit_train_batches=1,
+            
+            # deterministic=True,
+            check_val_every_n_epoch=5,
+        )
     # print(cfg.log_path)
     # print(trainer.loggers[0].log_dir)
     # print(trainer.loggers[1].log_dir)
     #----> Copy Code
+
+    # home = Path.cwd()[0]
+
     copy_path = Path(trainer.loggers[0].log_dir) / 'code'
     copy_path.mkdir(parents=True, exist_ok=True)
     copy_origin = '/' / Path('/'.join(cfg.log_path.parts[1:5])) / 'code'
+    # print(copy_path)
     # print(copy_origin)
     shutil.copytree(copy_origin, copy_path, dirs_exist_ok=True)
 
@@ -123,7 +200,9 @@ def main(cfg):
     #---->train or test
     if cfg.resume_training:
         last_ckpt = log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' / 'last.ckpt'
-        trainer.fit(model = model, datamodule = dm, ckpt_path=last_ckpt)
+        # model = model.load_from_checkpoint(last_ckpt)
+        # trainer.fit(model, dm) #, datamodule = dm
+        trainer.fit(model = model, ckpt_path=last_ckpt) #, datamodule = dm
 
     if cfg.General.server == 'train':
 
@@ -152,6 +231,29 @@ def main(cfg):
             trainer.test(model=new_model, datamodule=dm)
 
 
def check_home(cfg):
    """Rewrite absolute paths in *cfg* so they live under the current home root.

    The YAML configs hard-code paths under one top-level directory (e.g.
    ``/home/...`` vs ``/homeStor1/...``).  When running on a machine whose
    filesystem root component differs, each stored path is re-rooted onto the
    first component of the current working directory; paths that already
    match are left untouched.

    Args:
        cfg: config object exposing ``General.log_path``, ``Data.data_dir``
            and ``Data.label_file`` as absolute path strings.

    Returns:
        The same ``cfg`` object, mutated in place.
    """
    home = Path.cwd().parts[1]

    def _relocate(path_str):
        # Swap the first path component for the current home; keep the rest.
        parts = Path(path_str).parts
        if parts[1] == home:
            return path_str
        return '/' + str(Path(home).joinpath(*parts[2:]))

    cfg.General.log_path = _relocate(cfg.General.log_path)
    cfg.Data.data_dir = _relocate(cfg.Data.data_dir)
    cfg.Data.label_file = _relocate(cfg.Data.label_file)

    return cfg
+
+
 if __name__ == '__main__':
 
     args = make_parse()
@@ -159,12 +261,16 @@ if __name__ == '__main__':
 
     #---->update
     cfg.config = args.config
-    cfg.General.gpus = [args.gpus]
+    cfg.General.gpus = args.gpus
     cfg.General.server = args.stage
     cfg.Data.fold = args.fold
     cfg.Loss.base_loss = args.loss
     cfg.Data.bag_size = args.bag_size
     cfg.version = args.version
+    if args.label_file: 
+        cfg.Data.label_file = '/home/ylan/DeepGraft/training_tables/' + args.label_file
+
+    cfg = check_home(cfg)
 
     config_path = '/'.join(Path(cfg.config).parts[1:])
     log_path = Path(cfg.General.log_path) / str(Path(config_path).parent)
@@ -174,12 +280,12 @@ if __name__ == '__main__':
     Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
     log_name =  f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
     task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    cfg.task = task
     # task = Path(cfg.config).name[:-5].split('_')[2:][0]
     cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name 
     
     
-    
 
-    #---->main
+    # ---->main
     main(cfg)
  
\ No newline at end of file
diff --git a/code/utils/__init__.py b/code/utils/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5caecf700da5e467edb71103e550ca714f21f962 100644
--- a/code/utils/__init__.py
+++ b/code/utils/__init__.py
@@ -0,0 +1 @@
+from .custom_resnet50 import resnet50_baseline
\ No newline at end of file
diff --git a/code/utils/__pycache__/__init__.cpython-39.pyc b/code/utils/__pycache__/__init__.cpython-39.pyc
index 3dcd009c36ba1c22656489be43171004bc84df85..dcda699d2566f65d9c29aac8a8f8bc255eb8cf7e 100644
Binary files a/code/utils/__pycache__/__init__.cpython-39.pyc and b/code/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/utils/__pycache__/custom_resnet50.cpython-39.pyc b/code/utils/__pycache__/custom_resnet50.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84b461b38eae431d5b26a7cbd8e846cb71977b4b
Binary files /dev/null and b/code/utils/__pycache__/custom_resnet50.cpython-39.pyc differ
diff --git a/code/utils/__pycache__/utils.cpython-39.pyc b/code/utils/__pycache__/utils.cpython-39.pyc
index 4e86955d7bdf2df00f04cb9e61d452bc3f04da04..26de104113838d7e9f56ecbe311a3c8813d08d77 100644
Binary files a/code/utils/__pycache__/utils.cpython-39.pyc and b/code/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/code/utils/custom_resnet50.py b/code/utils/custom_resnet50.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c3f33518bc53a1611b94767bc50fbbb169154d0
--- /dev/null
+++ b/code/utils/custom_resnet50.py
@@ -0,0 +1,122 @@
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch
+import torch.nn.functional as F
+
# Public names mirrored from torchvision's resnet module; only resnet50 is
# actually constructed in this file (see resnet50_baseline below).
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

# Download URLs for the official torchvision ImageNet checkpoints, keyed by
# architecture name; consumed by load_pretrained_weights via model_zoo.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
+
class Bottleneck_Baseline(nn.Module):
    """Standard ResNet bottleneck block (1x1 reduce -> 3x3 -> 1x1 expand).

    The final 1x1 conv widens the channel count by ``expansion``; the skip
    connection goes through ``downsample`` when spatial size or channel
    count changes, otherwise it is the identity.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck_Baseline, self).__init__()
        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 spatial conv; carries the block's stride when downsampling.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 channel expansion back up to planes * expansion.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity (or projected) shortcut around the three conv stages.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(out + shortcut)
+
class ResNet_Baseline(nn.Module):
    """Truncated ResNet feature extractor (stem + layer1..layer3).

    Compared to a full ResNet, layer4 and the classification head are
    dropped; global average pooling flattens the layer3 output into one
    feature vector per image.
    """

    def __init__(self, block, layers):
        self.inplanes = 64
        super(ResNet_Baseline, self).__init__()
        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (4x downsampling).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # He initialization for convs, unit-gain affine for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        # Project the shortcut whenever the first block of the stage changes
        # resolution or channel count.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stages = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stages.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stages)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.avgpool(x)
        return x.view(x.size(0), -1)
+
def resnet50_baseline(pretrained=False):
    """Construct the modified (truncated) ResNet-50 feature extractor.

    Args:
        pretrained (bool): If True, initialise from the torchvision
            ImageNet checkpoint (keys absent from this truncated
            architecture are skipped).

    Returns:
        ResNet_Baseline: the constructed model.
    """
    net = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3])
    return load_pretrained_weights(net, 'resnet50') if pretrained else net
+
def load_pretrained_weights(model, name):
    """Load the torchvision ImageNet checkpoint *name* into *model*.

    Downloads (or reads from the torch hub cache) the checkpoint URL listed
    in ``model_urls`` and copies every matching parameter.  ``strict=False``
    because this truncated architecture has no layer4/fc, so those
    checkpoint keys are deliberately ignored rather than raising.
    """
    pretrained_dict = model_zoo.load_url(model_urls[name])
    model.load_state_dict(pretrained_dict, strict=False)
    return model
diff --git a/code/utils/extract_features.py b/code/utils/extract_features.py
index 9ec600e3cf5ee46cd48d57d55173c01ce8bbc7fd..daae76954509eb4d6bdfa45891e26bf23f9ad614 100644
--- a/code/utils/extract_features.py
+++ b/code/utils/extract_features.py
@@ -19,6 +19,38 @@ def extract_features(input_dir, output_dir, model, batch_size):
     model = model.to(device)
     model.eval()
 
+
+
+    for bag_candidate_idx in range(total):
+        bag_candidate = bags_dataset[bag_candidate_idx]
+        bag_name = os.path.basename(os.path.normpath(bag_candidate))
+        print(bag_name)
+        print('\nprogress: {}/{}'.format(bag_candidate_idx, total))
+        bag_base = bag_name.split('\\')[-1]
+        
+        if not os.path.exists(os.path.join(feat_dir, bag_base + '.pt')):
+            
+            print(bag_name)
+            
+            output_path = os.path.join(feat_dir, bag_name)
+            file_path = bag_candidate
+            print(file_path)
+            output_file_path = Compute_w_loader(file_path, output_path, 
+    												model = model, batch_size = batch_size, 
+    												verbose = 1, print_every = 20,
+    												target_patch_size = target_patch_size)
+                        
+            if os.path.exists (output_file_path):
+                file = h5py.File(output_file_path, "r")
+                features = file['features'][:]
+                
+                print('features size: ', features.shape)
+                print('coordinates size: ', file['coords'].shape)
+                
+                features = torch.from_numpy(features)
+                torch.save(features, os.path.join(feat_dir, bag_base+'.pt'))
+                file.close()
+
     
 
 
diff --git a/code/utils/utils.py b/code/utils/utils.py
index 7e992cffd810cd7cf54babc70477e8dc174d82e3..5596208e41716da17ac42e18741e09b99ff93667 100755
--- a/code/utils/utils.py
+++ b/code/utils/utils.py
@@ -65,7 +65,7 @@ def load_loggers(cfg):
 
 
 #---->load Callback
-from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, StochasticWeightAveraging
 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
@@ -77,11 +77,11 @@ def load_callbacks(cfg, save_path):
     output_path.mkdir(exist_ok=True, parents=True)
 
     early_stop_callback = EarlyStopping(
-        monitor='val_loss',
+        monitor='val_auc',
         min_delta=0.00,
         patience=cfg.General.patience,
         verbose=True,
-        mode='min'
+        mode='max'
     )
 
     Mycallbacks.append(early_stop_callback)
@@ -106,7 +106,7 @@ def load_callbacks(cfg, save_path):
                                          filename = '{epoch:02d}-{val_loss:.4f}-{val_auc: .4f}',
                                          verbose = True,
                                          save_last = True,
-                                         save_top_k = 1,
+                                         save_top_k = 2,
                                          mode = 'min',
                                          save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
@@ -114,9 +114,13 @@ def load_callbacks(cfg, save_path):
                                          filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
-                                         save_top_k = 1,
+                                         save_top_k = 2,
                                          mode = 'max',
                                          save_weights_only = True))
+    
+    swa = StochasticWeightAveraging(swa_lrs=1e-2)
+    Mycallbacks.append(swa)
+
     return Mycallbacks
 
 #---->val loss