diff --git a/DeepGraft/AttMIL_feat_norm_rej_rest.yaml b/DeepGraft/AttMIL_feat_norm_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4eda25f361c38a6dd9bc91154bc6c2bd740c2996
--- /dev/null
+++ b/DeepGraft/AttMIL_feat_norm_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
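+    # '&epoch' defines a YAML anchor; other entries can reuse this value via '*epoch'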
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 3
+    backbone: features
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: Adam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/AttMIL_feat_norm_rest.yaml b/DeepGraft/AttMIL_feat_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fefbe5ed47b398f6c45cfb3004f360dd95f0bdb6
--- /dev/null
+++ b/DeepGraft/AttMIL_feat_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 100
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: features
+    in_features: 512
+    out_features: 1024
+
+
+Optimizer:
+    opt: Adam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/AttMIL_feat_rej_rest.yaml b/DeepGraft/AttMIL_feat_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3d854bede1e1e6841222d2513fa5171678ef7c94
--- /dev/null
+++ b/DeepGraft/AttMIL_feat_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: features
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: Adam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/AttTrans_resnet50_norm_rest.yaml b/DeepGraft/AttTrans_resnet50_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cfffa7d9a5e13b2eae17b2c8d1636e2112807dd6
--- /dev/null
+++ b/DeepGraft/AttTrans_resnet50_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 32
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 100
+    server: train #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128uM_annotated/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_100_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: AttTrans
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 1024
+
+
+Optimizer:
+    opt: Adam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_feat_norm_rej_rest.yaml b/DeepGraft/TransMIL_feat_norm_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3dcc81775c0741d3f9cf02b8b501349bb70323c8
--- /dev/null
+++ b/DeepGraft/TransMIL_feat_norm_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 100
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransMIL
+    n_classes: 3
+    backbone: features
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_feat_norm_rest.yaml b/DeepGraft/TransMIL_feat_norm_rest.yaml
index dab6568ee8f7895affa58eec70bacc0f8adccfef..ea452a34b4d1a4ba83aa669486def30b9e5bb12e 100644
--- a/DeepGraft/TransMIL_feat_norm_rest.yaml
+++ b/DeepGraft/TransMIL_feat_norm_rest.yaml
@@ -3,20 +3,22 @@ General:
     seed: 2021
     fp16: True
     amp_level: O2
-    precision: 16 
+    precision: 16
     multi_gpu_mode: dp
-    gpus: [0]
+    gpus: [0, 1]
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
     patience: 50
-    server: test #train #test
+    server: train #train #test
     log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    mixup: False
+    aug: True
+    data_dir: '/home/ylan/data/DeepGraft/224_128uM_annotated/'
     label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
     fold: 1
     nfold: 3
@@ -34,12 +36,12 @@ Model:
     name: TransMIL
     n_classes: 2
     backbone: features
-    in_features: 512
+    in_features: 2048
     out_features: 1024
 
 
 Optimizer:
-    opt: lookahead_radam
+    opt: Adam
     lr: 0.0001
     opt_eps: null 
     opt_betas: null
diff --git a/DeepGraft/TransMIL_feat_rej_rest.yaml b/DeepGraft/TransMIL_feat_rej_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca9c0e47e15a2588a2fee91529653f2a7c07c735
--- /dev/null
+++ b/DeepGraft/TransMIL_feat_rej_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_rej_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: features
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_feat_rejections.yaml b/DeepGraft/TransMIL_feat_rejections.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a1d2ae6d81b8830c0591462b83e346e7d338073b
--- /dev/null
+++ b/DeepGraft/TransMIL_feat_rejections.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_HE_Jones_rejections.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 3
+    backbone: features
+    in_features: 1024
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/TransMIL_resnet50_norm_rej_rest.yaml b/DeepGraft/TransMIL_resnet50_norm_rej_rest.yaml
index 98857434cac5d09f25c7e68e05e41abf0c16c325..1fa23b5c999e9819fc6c4de07ce4c2f94244e48b 100644
--- a/DeepGraft/TransMIL_resnet50_norm_rej_rest.yaml
+++ b/DeepGraft/TransMIL_resnet50_norm_rej_rest.yaml
@@ -6,11 +6,11 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [0]
-    epochs: &epoch 500 
+    epochs: &epoch 1000 
     grad_acc: 2
     frozen_bn: False
-    patience: 50
-    server: test #train #test
+    patience: 100
+    server: train #train #test
     log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
diff --git a/DeepGraft/TransMIL_resnet50_norm_rest.yaml b/DeepGraft/TransMIL_resnet50_norm_rest.yaml
index b58c6d880b5707fcdd2aa39e65ae2fb77871bfcd..0511d268534c6e2a3af203a3663acc13dab7b486 100644
--- a/DeepGraft/TransMIL_resnet50_norm_rest.yaml
+++ b/DeepGraft/TransMIL_resnet50_norm_rest.yaml
@@ -17,7 +17,7 @@ Data:
     dataset_name: custom
     data_shuffle: False
     data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
-    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_limit_100_split_PAS_HE_Jones_norm_rest_RA_RU.json'
     fold: 1
     nfold: 3
     cross_val: False
diff --git a/DeepGraft/TransformerMIL_feat_norm_rest.yaml b/DeepGraft/TransformerMIL_feat_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7c90fbff0e56441623ff2ae68cf791776cf7d589
--- /dev/null
+++ b/DeepGraft/TransformerMIL_feat_norm_rest.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 100
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: features
+    in_features: 512
+    out_features: 1024
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft_Project_Plan.pdf b/DeepGraft_Project_Plan.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..1b5a0c95fca1a588080f74941d9141b5b45bb604
Binary files /dev/null and b/DeepGraft_Project_Plan.pdf differ
diff --git a/README.md b/README.md
index c8594bc7edf8bcade5f8d1c9749978fd0c71f98e..810ff74e5cba1be3bb7ccb987f7f6208100fbda5 100644
--- a/README.md
+++ b/README.md
@@ -31,8 +31,7 @@ wd = 0.01
 
 ### Ablation
 
-image drop out: 
-tcmr_viral TCMR efficientnet: version 0
+### Important things:
+
+    * 
 
-wd incerease: 
-tcmr_viral TCMR efficientnet: version 110
\ No newline at end of file
diff --git a/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc
index db4ebcceee99d56aa365237e82f0e543b7ba5cc1..5d4cc8bb3818896e5944e5eb2f6b551388b67e48 100644
Binary files a/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc and b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ
diff --git a/code/MyLoss/loss_factory.py b/code/MyLoss/loss_factory.py
index 61502bf66fda9c2d2f4ebb8df1566a1d56583b5e..f17f69804e81667ee796e7e509803ea5735829e4 100755
--- a/code/MyLoss/loss_factory.py
+++ b/code/MyLoss/loss_factory.py
@@ -29,7 +29,7 @@ def create_loss(args, n_classes, w1=1.0, w2=0.5):
     loss = None
     print(conf_loss)
     if hasattr(nn, conf_loss): 
-        loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
+        loss = getattr(nn, conf_loss)()
         # loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
     #binary loss
     elif conf_loss == "focal":
diff --git a/code/__pycache__/test_visualize.cpython-39.pyc b/code/__pycache__/test_visualize.cpython-39.pyc
index e961ce76be403bfa75de922eeedcbaa056c2b97c..c3a94d6f97af78fa95e1263a76b8e60928a923d5 100644
Binary files a/code/__pycache__/test_visualize.cpython-39.pyc and b/code/__pycache__/test_visualize.cpython-39.pyc differ
diff --git a/code/datasets/ResNet.py b/code/datasets/ResNet.py
index f8fe70648f55b58eacc5da0578821d7842c563f3..5c3c2776f419a4b9ccc31e101f9ee442b6540b99 100644
--- a/code/datasets/ResNet.py
+++ b/code/datasets/ResNet.py
@@ -394,4 +394,12 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
     """
     kwargs['width_per_group'] = 64 * 2
     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
-                   pretrained, progress, **kwargs)
\ No newline at end of file
+                   pretrained, progress, **kwargs)
+
+if __name__ == '__main__':
+
+    model_ft = resnet50(num_classes=1024, mlp=False, two_branch=False, normlinear=True)
+
+    # model_ft.fc = nn.Identity()
+    print(model_ft)
\ No newline at end of file
diff --git a/code/datasets/__pycache__/ResNet.cpython-39.pyc b/code/datasets/__pycache__/ResNet.cpython-39.pyc
index 3d963e71935c049d798681b7b63ea581b34f4472..6abf8d43e83aa27f6d60f4b306f2509936d359c5 100644
Binary files a/code/datasets/__pycache__/ResNet.cpython-39.pyc and b/code/datasets/__pycache__/ResNet.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
index 9f536c5dcea366ca8fa4c041f92b33ca5546e2a4..549f6fa590ac3a8baf0494d6d9c5c7431c562d8f 100644
Binary files a/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc and b/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/data_interface.cpython-39.pyc b/code/datasets/__pycache__/data_interface.cpython-39.pyc
index 2e97390e89b0ccb4905df819317f14850af46697..e1151f291a6f1bd56821ff01a34f3019adc21b35 100644
Binary files a/code/datasets/__pycache__/data_interface.cpython-39.pyc and b/code/datasets/__pycache__/data_interface.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/feature_dataloader.cpython-39.pyc b/code/datasets/__pycache__/feature_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10319c23951eae75c10e1a3f2836050359059030
Binary files /dev/null and b/code/datasets/__pycache__/feature_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/simple_jpg_dataloader.cpython-39.pyc b/code/datasets/__pycache__/simple_jpg_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..782f8d99cad131fbd9c433fc963873773108d79b
Binary files /dev/null and b/code/datasets/__pycache__/simple_jpg_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/zarr_feature_dataloader.cpython-39.pyc b/code/datasets/__pycache__/zarr_feature_dataloader.cpython-39.pyc
index 6316a91dcd30db8d68e2af702b444f46463713d2..ae34e4af9a83aa088b82ae14f314d71bcb54fcf0 100644
Binary files a/code/datasets/__pycache__/zarr_feature_dataloader.cpython-39.pyc and b/code/datasets/__pycache__/zarr_feature_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/zarr_feature_dataloader_simple.cpython-39.pyc b/code/datasets/__pycache__/zarr_feature_dataloader_simple.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94974784ee317472ad2486c5810ae4bdd346192a
Binary files /dev/null and b/code/datasets/__pycache__/zarr_feature_dataloader_simple.cpython-39.pyc differ
diff --git a/code/datasets/custom_jpg_dataloader.py b/code/datasets/custom_jpg_dataloader.py
index c28acc1a378e95ad770a045e583f6ee0011e2181..95d5adb331955464376b8361c38a2e1f65d76be1 100644
--- a/code/datasets/custom_jpg_dataloader.py
+++ b/code/datasets/custom_jpg_dataloader.py
@@ -212,7 +212,7 @@ class JPGMILDataloader(data.Dataset):
         #     print(out_batch.shape)
         # out_batch = torch.permute(out_batch, (0, 2,1,3))
         label = torch.as_tensor(label)
-        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
         # print(out_batch)
         return out_batch, label, (name, batch_names, patient) #, name_batch
 
diff --git a/code/datasets/data_interface.py b/code/datasets/data_interface.py
index fc4acabfae48ace4abb9833a2b67cbbecde6f2e8..7049b225637bf349978a69ac33a9f2743f97c8bf 100644
--- a/code/datasets/data_interface.py
+++ b/code/datasets/data_interface.py
@@ -6,121 +6,125 @@ import pytorch_lightning as pl
 # from pytorch_lightning.loops.fit_loop import FitLoop
 
 from torch.utils.data import random_split, DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
 from torch.utils.data.dataset import Dataset, Subset
 from torchvision.datasets import MNIST
 from torchvision import transforms
-from .camel_dataloader import FeatureBagLoader
+# from .camel_dataloader import FeatureBagLoader
 from .custom_dataloader import HDF5MILDataloader
-from .custom_jpg_dataloader import JPGMILDataloader
-from .zarr_feature_dataloader import ZarrFeatureBagLoader
+# from .custom_jpg_dataloader import JPGMILDataloader
+from .simple_jpg_dataloader import JPGBagLoader
+from .zarr_feature_dataloader_simple import ZarrFeatureBagLoader
+from .feature_dataloader import FeatureBagLoader
 from pathlib import Path
 # from transformers import AutoFeatureExtractor
 from torchsampler import ImbalancedDatasetSampler
 
 from abc import ABC, abstractclassmethod, abstractmethod
 from sklearn.model_selection import KFold
+import numpy as np
+import torch
 
 
+# class DataInterface(pl.LightningDataModule):
 
-class DataInterface(pl.LightningDataModule):
+#     def __init__(self, train_batch_size=64, train_num_workers=8, test_batch_size=1, test_num_workers=1,dataset_name=None, **kwargs):
+#         """[summary]
 
-    def __init__(self, train_batch_size=64, train_num_workers=8, test_batch_size=1, test_num_workers=1,dataset_name=None, **kwargs):
-        """[summary]
+#         Args:
+#             batch_size (int, optional): [description]. Defaults to 64.
+#             num_workers (int, optional): [description]. Defaults to 8.
+#             dataset_name (str, optional): [description]. Defaults to ''.
+#         """        
+#         super().__init__()
 
-        Args:
-            batch_size (int, optional): [description]. Defaults to 64.
-            num_workers (int, optional): [description]. Defaults to 8.
-            dataset_name (str, optional): [description]. Defaults to ''.
-        """        
-        super().__init__()
-
-        self.train_batch_size = train_batch_size
-        self.train_num_workers = train_num_workers
-        self.test_batch_size = test_batch_size
-        self.test_num_workers = test_num_workers
-        self.dataset_name = dataset_name
-        self.kwargs = kwargs
-        self.load_data_module()
-        home = Path.cwd().parts[1]
-        self.data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv'
+#         self.train_batch_size = train_batch_size
+#         self.train_num_workers = train_num_workers
+#         self.test_batch_size = test_batch_size
+#         self.test_num_workers = test_num_workers
+#         self.dataset_name = dataset_name
+#         self.kwargs = kwargs
+#         self.load_data_module()
+#         home = Path.cwd().parts[1]
+#         self.data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv'
 
  
 
-    def prepare_data(self):
-        # 1. how to download
-        # MNIST(self.data_dir, train=True, download=True)
-        # MNIST(self.data_dir, train=False, download=True)
-        ...
-
-    def setup(self, stage=None):
-        # 2. how to split, argument
-        """  
-        - count number of classes
-
-        - build vocabulary
-
-        - perform train/val/test splits
-
-        - apply transforms (defined explicitly in your datamodule or assigned in init)
-        """
-        # Assign train/val datasets for use in dataloaders
-        if stage == 'fit' or stage is None:
-            dataset = FeatureBagLoader(data_root = self.data_root,
-                                                train=True)
-            a = int(len(dataset)* 0.8)
-            b = int(len(dataset) - a)
-            # print(a)
-            # print(b)
-            self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) # returns data.Subset
-
-            # self.train_dataset = self.instancialize(state='train')
-            # self.val_dataset = self.instancialize(state='val')
+#     def prepare_data(self):
+#         # 1. how to download
+#         # MNIST(self.data_dir, train=True, download=True)
+#         # MNIST(self.data_dir, train=False, download=True)
+#         ...
+
+#     def setup(self, stage=None):
+#         # 2. how to split, argument
+#         """  
+#         - count number of classes
+
+#         - build vocabulary
+
+#         - perform train/val/test splits
+
+#         - apply transforms (defined explicitly in your datamodule or assigned in init)
+#         """
+#         # Assign train/val datasets for use in dataloaders
+#         if stage == 'fit' or stage is None:
+#             dataset = FeatureBagLoader(data_root = self.data_root,
+#                                                 train=True)
+#             a = int(len(dataset)* 0.8)
+#             b = int(len(dataset) - a)
+#             # print(a)
+#             # print(b)
+#             self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) # returns data.Subset
+
+#             # self.train_dataset = self.instancialize(state='train')
+#             # self.val_dataset = self.instancialize(state='val')
  
 
-        # Assign test dataset for use in dataloader(s)
-        if stage == 'test' or stage is None:
-            # self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
-            self.test_dataset = FeatureBagLoader(data_root = self.data_root,
-                                                train=False)
-            # self.test_dataset = self.instancialize(state='test')
+#         # Assign test dataset for use in dataloader(s)
+#         if stage == 'test' or stage is None:
+#             # self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
+#             self.test_dataset = FeatureBagLoader(data_root = self.data_root,
+#                                                 train=False)
+#             # self.test_dataset = self.instancialize(state='test')
 
 
-    def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
+#     def train_dataloader(self):
+#         return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
 
-    def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
+#     def val_dataloader(self):
+#         return DataLoader(self.val_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
 
-    def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=self.test_batch_size, num_workers=self.test_num_workers, shuffle=False)
+#     def test_dataloader(self):
+#         return DataLoader(self.test_dataset, batch_size=self.test_batch_size, num_workers=self.test_num_workers, shuffle=False)
 
 
-    def load_data_module(self):
-        camel_name =  ''.join([i.capitalize() for i in (self.dataset_name).split('_')])
-        try:
-            self.data_module = getattr(importlib.import_module(
-                f'datasets.{self.dataset_name}'), camel_name)
-        except:
-            raise ValueError(
-                'Invalid Dataset File Name or Invalid Class Name!')
+#     def load_data_module(self):
+#         camel_name =  ''.join([i.capitalize() for i in (self.dataset_name).split('_')])
+#         try:
+#             self.data_module = getattr(importlib.import_module(
+#                 f'datasets.{self.dataset_name}'), camel_name)
+#         except:
+#             raise ValueError(
+#                 'Invalid Dataset File Name or Invalid Class Name!')
     
-    def instancialize(self, **other_args):
-        """ Instancialize a model using the corresponding parameters
-            from self.hparams dictionary. You can also input any args
-            to overwrite the corresponding value in self.kwargs.
-        """
-        class_args = inspect.getargspec(self.data_module.__init__).args[1:]
-        inkeys = self.kwargs.keys()
-        args1 = {}
-        for arg in class_args:
-            if arg in inkeys:
-                args1[arg] = self.kwargs[arg]
-        args1.update(other_args)
-        return self.data_module(**args1)
+#     def instancialize(self, **other_args):
+#         """ Instancialize a model using the corresponding parameters
+#             from self.hparams dictionary. You can also input any args
+#             to overwrite the corresponding value in self.kwargs.
+#         """
+#         class_args = inspect.getargspec(self.data_module.__init__).args[1:]
+#         inkeys = self.kwargs.keys()
+#         args1 = {}
+#         for arg in class_args:
+#             if arg in inkeys:
+#                 args1[arg] = self.kwargs[arg]
+#         args1.update(other_args)
+#         return self.data_module(**args1)
 
 class MILDataModule(pl.LightningDataModule):
 
-    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, use_features=False, *args, **kwargs):
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, use_features=False, mixup=False, aug=False, *args, **kwargs):
         super().__init__()
         self.data_root = data_root
         self.label_path = label_path
@@ -134,32 +138,37 @@ class MILDataModule(pl.LightningDataModule):
         self.num_bags_train = 200
         self.num_bags_test = 50
         self.seed = 1
+        self.mixup = mixup
+        self.aug = aug
 
 
+        self.class_weight = []
         self.cache = cache
         self.fe_transform = None
         if not use_features: 
-            self.base_dataloader = JPGMILDataloader
+            self.base_dataloader = JPGBagLoader
         else: 
-            
-            self.base_dataloader = ZarrFeatureBagLoader
+            self.base_dataloader = FeatureBagLoader
             self.cache = True
-        
-
 
     def setup(self, stage: Optional[str] = None) -> None:
         home = Path.cwd().parts[1]
 
         if stage in (None, 'fit'):
-            dataset = self.base_dataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, cache=self.cache)
+            dataset = self.base_dataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, cache=self.cache, mixup=self.mixup, aug=self.aug)
             # dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
             print(len(dataset))
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
             self.train_data, self.valid_data = random_split(dataset, [a, b])
 
+            # self.weights = self.get_weights(dataset)
+
+
+
         if stage in (None, 'test'):
-            self.test_data = self.base_dataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1)
+            self.test_data = self.base_dataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, cache=False)
+            print(len(self.test_data))
 
         return super().setup(stage=stage)
 
@@ -167,6 +176,7 @@ class MILDataModule(pl.LightningDataModule):
 
     def train_dataloader(self) -> DataLoader:
         # return DataLoader(self.train_data,  batch_size = self.batch_size, num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        # return DataLoader(self.train_data,  batch_size = self.batch_size, sampler = WeightedRandomSampler(self.weights, len(self.weights)), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         return DataLoader(self.train_data,  batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         #sampler=ImbalancedDatasetSampler(self.train_data)
     def val_dataloader(self) -> DataLoader:
@@ -174,6 +184,22 @@ class MILDataModule(pl.LightningDataModule):
     
     def test_dataloader(self) -> DataLoader:
         return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
+
+    def get_weights(self, dataset):
+
+        label_count = [0]*self.n_classes
+        labels = dataset.get_labels(np.arange(len(dataset)))
+        for i in labels:
+            label_count[i] += 1
+        weights_per_class = [0.] * self.n_classes
+        for i in range(self.n_classes):
+            weights_per_class[i] = float(len(labels) / float(label_count[i]))
+        weights_per_class = [i / sum(weights_per_class) for i in weights_per_class]
+        weights = [0.] * len(labels)
+        for i in range(len(labels)):
+            weights[i] = weights_per_class[labels[i]]
+
+        return torch.DoubleTensor(weights)
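+
+    # A minimal sketch (not wired in) of how these weights could drive sampling,
+    # mirroring the commented-out WeightedRandomSampler line in train_dataloader:
+    #   weights = self.get_weights(dataset)
+    #   sampler = WeightedRandomSampler(weights, len(weights))
+    #   loader = DataLoader(self.train_data, batch_size=self.batch_size, sampler=sampler)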
     
 
 class DataModule(pl.LightningDataModule):
diff --git a/code/datasets/feature_dataloader.py b/code/datasets/feature_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e9dbb567b7a791386e56a86127958df8e617e0b
--- /dev/null
+++ b/code/datasets/feature_dataloader.py
@@ -0,0 +1,394 @@
+import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from torchsampler import ImbalancedDatasetSampler
+from torchvision import datasets, transforms
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+import zarr
+import json
+import cv2
+from PIL import Image
+import h5py
+# from models import TransMIL
+
+
+
+class FeatureBagLoader(data.Dataset):
+    def __init__(self, file_path, label_path, mode, n_classes, cache=False, mixup=False, aug=False, data_cache_size=5000, max_bag_size=1000):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.drop_rate = 0.2
+        # self.min_bag_size = 120
+        self.empty_slides = []
+        self.corrupt_slides = []
+        self.cache = cache
+        self.mixup = mixup
+        self.aug = aug
+        
+        self.missing = []
+
+        home = Path.cwd().parts[1]
+        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict_an.json'
+        with open(self.slide_patient_dict_path, 'r') as f:
+            self.slide_patient_dict = json.load(f)
+
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            json_dict = json.load(f)
+            temp_slide_label_dict = json_dict[self.mode]
+            # print(len(temp_slide_label_dict))
+            for (x,y) in temp_slide_label_dict:
+                
+                x_name = Path(x).stem
+                x_path_list = [Path(self.file_path)/x]
+                # x_name = x.stem
+                # x_path_list = [Path(self.file_path)/ x for (x,y) in temp_slide_label_dict]
+                if self.aug:
+                    for i in range(5):
+                        aug_path = Path(self.file_path)/f'{x}_aug{i}'
+                        x_path_list.append(aug_path)
+
+                for x_path in x_path_list: 
+                    
+                    if x_path.exists():
+                        self.slideLabelDict[x_name] = y
+                        self.files.append(x_path)
+                    elif Path(str(x_path) + '.zarr').exists():
+                        self.slideLabelDict[x_name] = y
+                        self.files.append(str(x_path)+'.zarr')
+                    else:
+                        self.missing.append(x)
+                # print(x, y)
+                # x_complete_path = Path(self.file_path)/Path(x)
+                # for cohort in Path(self.file_path).iterdir():
+                #     # x_complete_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL' / (str(x) + '.zarr')
+                #     # if self.mode == 'test': #set to test if using GAN output
+                #     #     x_path_list = [Path(self.file_path) / cohort / 'FE' / (str(x) + '.zarr')]
+                #     # else:
+                #     # x_path_list = [Path(self.file_path) / cohort / 'FEATURES' / (str(x))]
+                #     x_path_list = [Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x))]
+                #     # if not self.mixup:
+                #     for i in range(5):
+                #         aug_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x) + f'_aug{i}')
+                #         if aug_path.exists():
+                #             x_path_list.append(aug_path)
+                #     # print(x_complete_path)
+                #     for x_path in x_path_list:
+                #         # print(x_path)
+                        
+                #         if x_path.exists():
+                #             # print(x_path)
+                #             # if len(list(x_complete_path.iterdir())) > self.min_bag_size:
+                #             # # print(x_complete_path)
+                #             self.slideLabelDict[x] = y
+                #             self.files.append(x_path)
+                #         elif Path(str(x_path) + '.zarr').exists():
+                #             self.slideLabelDict[x] = y
+                #             self.files.append(str(x_path)+'.zarr')
+                #         else:
+                #             self.missing.append(x)
+        
+        # mix in 10 Slides of Test data
+            # if 'test_mixin' in json_dict.keys():
+            #     test_slide_label_dict = json_dict['test']
+            #     for (x, y) in test_slide_label_dict:
+            #         x = Path(x).stem
+            #         for cohort in Path(self.file_path).iterdir():
+            #             x_path_list = [Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x))]
+            #             for x_path in x_path_list:
+            #                 if x_path.exists():
+            #                     self.slideLabelDict[x] = y
+            #                     self.files.append(x_path)
+            #                     patient = self.slide_patient_dict[x]
+            #                 elif Path(str(x_path) + '.zarr').exists():
+            #                     self.slideLabelDict[x] = y
+            #                     self.files.append(str(x_path)+'.zarr')
+
+
+
+        
+
+        self.feature_bags = []
+        self.labels = []
+        self.wsi_names = []
+        self.coords = []
+        self.patients = []
+        if self.cache:
+            for t in tqdm(self.files):
+                # zarr_t = str(t) + '.zarr'
+                batch, label, (wsi_name, batch_coords, patient) = self.get_data(t)
+
+                # print(label)
+                self.labels.append(label)
+                self.feature_bags.append(batch)
+                self.wsi_names.append(wsi_name)
+                self.coords.append(batch_coords)
+                self.patients.append(patient)
+        
+
+    def get_data(self, file_path):
+        
+        batch_names=[] #add function for name_batch read out
+
+        wsi_name = Path(file_path).stem
+        if wsi_name.split('_')[-1][:3] == 'aug':
+            wsi_name = '_'.join(wsi_name.split('_')[:-1])
+        # if wsi_name in self.slideLabelDict:
+        label = self.slideLabelDict[wsi_name]
+        patient = self.slide_patient_dict[wsi_name]
+
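+        # features are stored either as zarr groups ('data'/'coords' arrays) or
+        # as HDF5 files ('features'/'coords' datasets); handle both layouts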
+        if Path(file_path).suffix == '.zarr':
+            z = zarr.open(file_path, 'r')
+            np_bag = np.array(z['data'][:])
+            coords = np.array(z['coords'][:])
+        else:
+            with h5py.File(file_path, 'r') as hdf5_file:
+                np_bag = hdf5_file['features'][:]
+                coords = hdf5_file['coords'][:]
+
+        # np_bag = torch.load(file_path)
+        # z = zarr.open(file_path, 'r')
+        # np_bag = np.array(z['data'][:])
+        # np_bag = np.array(zarr.open(file_path, 'r')).astype(np.uint8)
+        # label = torch.as_tensor(label)
+        label = int(label)
+        wsi_bag = torch.from_numpy(np_bag)
+        batch_coords = torch.from_numpy(coords)
+
+        return wsi_bag, label, (wsi_name, batch_coords, patient)
+    
+    def get_labels(self, indices):
+        # for i in indices: 
+        #     print(self.labels[i])
+        return [self.labels[i] for i in indices]
+
+
+    def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
+
+        # subsample bag instances down to bag_size (duplication / zero-padding variants are commented out below)
+
+        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+        bag_samples = bag[bag_idxs]
+        name_samples = [names[i] for i in bag_idxs]
+        # bag_sample_names = [bag_names[i] for i in bag_idxs]
+        # q, r  = divmod(bag_size, bag_samples.shape[0])
+        # if q > 0:
+        #     bag_samples = torch.cat([bag_samples]*q, 0)
+
+        # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+        # zero-pad if we don't have enough samples
+        # zero_padded = torch.cat((bag_samples,
+        #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+        return bag_samples, name_samples, min(bag_size, len(bag))
+
+    def data_dropout(self, bag, batch_names, drop_rate):
+        # bag_size = self.max_bag_size
+        bag_size = bag.shape[0]
+        # permute over the actual bag size, not max_bag_size, to avoid out-of-range indices
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_samples = bag[bag_idxs]
+        name_samples = [batch_names[i] for i in bag_idxs]
+
+        return bag_samples, name_samples
+
+    def get_mixup_bag(self, bag):
+
+        bag_size = bag.shape[0]
+
+        a = torch.rand([bag_size])
+        b = 0.6
+        rand_x = torch.randint(0, bag_size, [bag_size,])
+        rand_y = torch.randint(0, bag_size, [bag_size,])
+
+        bag_x = bag[rand_x, :]
+        bag_y = bag[rand_y, :]
+
+        # print('bag_x: ', bag_x.shape)
+        # print('bag_y: ', bag_y.shape)
+        # print('a*bag_x: ', (a*bag_x).shape)
+        # print('(1.0-a)*bag_y: ', ((1.0-a)*bag_y).shape)
+
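+        # transpose so the per-instance coefficients in 'a' broadcast over the
+        # feature dimension: row i becomes a[i]*bag_x[i] + (1-a[i])*bag_y[i]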
+        temp_bag = (bag_x.t()*a).t() + (bag_y.t()*(1.0-a)).t()
+        # print('temp_bag: ', temp_bag.shape)
+
+        if bag_size < self.max_bag_size:
+            diff = self.max_bag_size - bag_size
+            bag_idxs = torch.randperm(bag_size)[:diff]
+            
+            # print('bag: ', bag.shape)
+            # print('bag_idxs: ', bag_idxs.shape)
+            mixup_bag = torch.cat((bag, temp_bag[bag_idxs, :]))
+            # print('mixup_bag: ', mixup_bag.shape)
+        else:
+            random_sample_list = torch.rand(bag_size)
+            # keep the original instance when the random draw exceeds b, otherwise take the mixed one
+            mixup_bag = [bag[i] if random_sample_list[i] > b else temp_bag[i] for i in range(bag_size)] # TODO: make pytorch native (torch.where)
+            mixup_bag = torch.stack(mixup_bag)
+            # print('else')
+            # print(mixup_bag.shape)
+
+        return mixup_bag
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+
+        if self.cache:
+            label = self.labels[index]
+            bag = self.feature_bags[index]
+            
+        
+            
+            # label = Variable(Tensor(label))
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+            wsi_name = self.wsi_names[index]
+            batch_coords = self.coords[index]
+            patient = self.patients[index]
+
+            
+            #random dropout
+            #shuffle
+
+            # feats = Variable(Tensor(feats))
+            # return wsi, label, (wsi_name, batch_coords, patient)
+        else:
+            t = self.files[index]
+            bag, label, (wsi_name, batch_coords, patient) = self.get_data(t)
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+                # self.labels.append(label)
+                # self.feature_bags.append(batch)
+                # self.wsi_names.append(wsi_name)
+                # self.name_batches.append(name_batch)
+                # self.patients.append(patient)
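+
+        # training: subsample up to max_bag_size instances, optionally apply
+        # feature mixup, zero-pad to a fixed bag size and shuffle once more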
+        if self.mode == 'train':
+            bag_size = bag.shape[0]
+
+            bag_idxs = torch.randperm(bag_size)[:self.max_bag_size]
+            # bag_idxs = torch.randperm(bag_size)[:int(self.max_bag_size*(1-self.drop_rate))]
+            out_bag = bag[bag_idxs, :]
+            if self.mixup:
+                out_bag = self.get_mixup_bag(out_bag)
+                # batch_coords = 
+            if out_bag.shape[0] < self.max_bag_size:
+                out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+
+            # shuffle again
+            out_bag_idxs = torch.randperm(out_bag.shape[0])
+            out_bag = out_bag[out_bag_idxs]
+
+
+            # batch_coords only useful for test
+            batch_coords = batch_coords[bag_idxs]
+            
+
+        # mixup? Linear combination of 2 vectors
+        # add noise
+
+
+        else: out_bag = bag
+
+        return out_bag, label, (wsi_name, batch_coords, patient)
+
+if __name__ == '__main__':
+    
+    from pathlib import Path
+    import os
+    import time
+    # from fast_tensor_dl import FastTensorDataLoader
+    # from custom_resnet50 import resnet50_baseline
+    
+    
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128uM_annotated'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    # label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_test.json'
+    output_dir = f'{data_root}/debug/augments' # data_root already starts with '/'
+    os.makedirs(output_dir, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = FeatureBagLoader(data_root, label_path=label_path, mode='train', cache=False, mixup=True, aug=True, n_classes=n_classes)
+
+    test_dataset = FeatureBagLoader(data_root, label_path=label_path, mode='test', cache=False, n_classes=n_classes)
+
+    # print(dataset.get_labels(0))
+    a = int(len(dataset)* 0.8)
+    b = int(len(dataset) - a)
+    train_data, valid_data = random_split(dataset, [a, b])
+
+    train_dl = DataLoader(train_data, batch_size=1, num_workers=5)
+    valid_dl = DataLoader(valid_data, batch_size=1, num_workers=5)
+    test_dl = DataLoader(test_dataset)
+
+    print('train_dl: ', len(train_dl))
+    print('valid_dl: ', len(valid_dl))
+    print('test_dl: ', len(test_dl))
+
+
+    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # scaler = torch.cuda.amp.GradScaler()
+
+    # model_ft = resnet50_baseline(pretrained=True)
+    # for param in model_ft.parameters():
+    #     param.requires_grad = False
+    # model_ft.to(device)
+    # model = TransMIL(n_classes=n_classes).to(device)
+    
+
+    # print(dataset.get_labels(np.arange(len(dataset))))
+
+    c = 0
+    label_count = [0] *n_classes
+    epochs = 1
+    # print(len(dl))
+    # start = time.time()
+    for i in range(epochs):
+        start = time.time()
+        for item in tqdm(train_dl): 
+
+            # if c >= 10:
+            #     break
+            bag, label, (name, batch_coords, patient) = item
+            print(bag.shape)
+            # print(bag.shape, label)
+            # print(len(batch_names))
+            # print(label)
+            # print(batch_coords)
+            # print(name)
+            # bag = bag.float().to(device)
+            # print(bag.shape)
+            # label = label.to(device)
+            # with torch.cuda.amp.autocast():
+            #     output = model(bag)
+            # c += 1
+        end = time.time()
+        print('Bag Time: ', end-start)
+
+    
\ No newline at end of file
diff --git a/code/datasets/feature_dataloader_deca.py b/code/datasets/feature_dataloader_deca.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0543c5343593c074ac15b89778ab82b86db04c2
--- /dev/null
+++ b/code/datasets/feature_dataloader_deca.py
@@ -0,0 +1,320 @@
+import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from torchsampler import ImbalancedDatasetSampler
+from torchvision import datasets, transforms
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+import zarr
+import json
+import cv2
+from PIL import Image
+import h5py
+# from models import TransMIL
+
+
+
+class FeatureBagLoader(data.Dataset):
+    def __init__(self, file_path, label_path, mode, n_classes, cache=False, data_cache_size=5000, max_bag_size=1000):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.drop_rate = 0.2
+        # self.min_bag_size = 120
+        self.empty_slides = []
+        self.corrupt_slides = []
+        self.cache = cache
+        
+        self.missing = []
+
+        home = Path.cwd().parts[1]
+        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict_an.json'
+        with open(self.slide_patient_dict_path, 'r') as f:
+            self.slide_patient_dict = json.load(f)
+
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            json_dict = json.load(f)
+            temp_slide_label_dict = json_dict[self.mode]
+            # print(len(temp_slide_label_dict))
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem
+                
+                # print(x, y)
+                # x_complete_path = Path(self.file_path)/Path(x)
+                for cohort in Path(self.file_path).iterdir():
+                    # x_complete_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL' / (str(x) + '.zarr')
+                    # if self.mode == 'test': #set to test if using GAN output
+                    #     x_path_list = [Path(self.file_path) / cohort / 'FE' / (str(x) + '.zarr')]
+                    # else:
+                    # x_path_list = [Path(self.file_path) / cohort / 'FEATURES' / (str(x))]
+                    x_path_list = [Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x))]
+                    for i in range(5):
+                        aug_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x) + f'_aug{i}')
+                        if aug_path.exists():
+                            x_path_list.append(aug_path)
+                    # print(x_complete_path)
+                    for x_path in x_path_list:
+                        # print(x_path)
+                        
+                        if x_path.exists():
+                            # print(x_path)
+                            # if len(list(x_complete_path.iterdir())) > self.min_bag_size:
+                            # # print(x_complete_path)
+                            self.slideLabelDict[x] = y
+                            self.files.append(x_path)
+                        elif Path(str(x_path) + '.zarr').exists():
+                            self.slideLabelDict[x] = y
+                            self.files.append(str(x_path)+'.zarr')
+                        else:
+                            self.missing.append(x)
+        
+        # mix in 10 Slides of Test data
+            if 'test_mixin' in json_dict.keys():
+                test_slide_label_dict = json_dict['test']
+                for (x, y) in test_slide_label_dict:
+                    x = Path(x).stem
+                    for cohort in Path(self.file_path).iterdir():
+                        x_path_list = [Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x))]
+                        for x_path in x_path_list:
+                            if x_path.exists():
+                                self.slideLabelDict[x] = y
+                                self.files.append(x_path)
+                                patient = self.slide_patient_dict[x]
+                            elif Path(str(x_path) + '.zarr').exists():
+                                self.slideLabelDict[x] = y
+                                self.files.append(str(x_path)+'.zarr')
+
+
+
+        
+
+        self.feature_bags = []
+        self.labels = []
+        self.wsi_names = []
+        self.coords = []
+        self.patients = []
+        if self.cache:
+            for t in tqdm(self.files):
+                # zarr_t = str(t) + '.zarr'
+                batch, label, (wsi_name, batch_coords, patient) = self.get_data(t)
+
+                # print(label)
+                self.labels.append(label)
+                self.feature_bags.append(batch)
+                self.wsi_names.append(wsi_name)
+                self.coords.append(batch_coords)
+                self.patients.append(patient)
+        
+
+    def get_data(self, file_path):
+        
+        batch_names=[] #add function for name_batch read out
+
+        wsi_name = Path(file_path).stem
+        if wsi_name.split('_')[-1][:3] == 'aug':
+            wsi_name = '_'.join(wsi_name.split('_')[:-1])
+        # if wsi_name in self.slideLabelDict:
+        label = self.slideLabelDict[wsi_name]
+        patient = self.slide_patient_dict[wsi_name]
+
+        if Path(file_path).suffix == '.zarr':
+            z = zarr.open(file_path, 'r')
+            np_bag = np.array(z['data'][:])
+            coords = np.array(z['coords'][:])
+        else:
+            with h5py.File(file_path, 'r') as hdf5_file:
+                np_bag = hdf5_file['features'][:]
+                coords = hdf5_file['coords'][:]
+
+        # np_bag = torch.load(file_path)
+        # z = zarr.open(file_path, 'r')
+        # np_bag = np.array(z['data'][:])
+        # np_bag = np.array(zarr.open(file_path, 'r')).astype(np.uint8)
+        # label = torch.as_tensor(label)
+        label = int(label)
+        wsi_bag = torch.from_numpy(np_bag)
+        batch_coords = torch.from_numpy(coords)
+
+        return wsi_bag, label, (wsi_name, batch_coords, patient)
+    
+    def get_labels(self, indices):
+        # for i in indices: 
+        #     print(self.labels[i])
+        return [self.labels[i] for i in indices]
+
+
+    def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
+
+        # subsample bag instances down to bag_size (duplication / zero-padding variants are commented out below)
+
+        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+        bag_samples = bag[bag_idxs]
+        name_samples = [names[i] for i in bag_idxs]
+        # bag_sample_names = [bag_names[i] for i in bag_idxs]
+        # q, r  = divmod(bag_size, bag_samples.shape[0])
+        # if q > 0:
+        #     bag_samples = torch.cat([bag_samples]*q, 0)
+
+        # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+        # zero-pad if we don't have enough samples
+        # zero_padded = torch.cat((bag_samples,
+        #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+        return bag_samples, name_samples, min(bag_size, len(bag))
+
+    def data_dropout(self, bag, batch_names, drop_rate):
+        # bag_size = self.max_bag_size
+        bag_size = bag.shape[0]
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_samples = bag[bag_idxs]
+        name_samples = [batch_names[i] for i in bag_idxs]
+
+        return bag_samples, name_samples
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+
+        if self.cache:
+            label = self.labels[index]
+            bag = self.feature_bags[index]
+            
+        
+            
+            # label = Variable(Tensor(label))
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+            wsi_name = self.wsi_names[index]
+            batch_coords = self.coords[index]
+            patient = self.patients[index]
+
+            
+            #random dropout
+            #shuffle
+
+            # feats = Variable(Tensor(feats))
+            # return wsi, label, (wsi_name, batch_coords, patient)
+        else:
+            t = self.files[index]
+            bag, label, (wsi_name, batch_coords, patient) = self.get_data(t)
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        if self.mode != 'test':
+            bag_size = bag.shape[0]
+            
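+            # subsample to at most max_bag_size*(1-drop_rate) random instances, then
+            # zero-pad so every training bag has exactly max_bag_size rows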
+            bag_idxs = torch.randperm(bag_size)[:int(self.max_bag_size*(1-self.drop_rate))]
+            out_bag = bag[bag_idxs, :]
+            batch_coords = batch_coords[bag_idxs]
+            if out_bag.shape[0] < self.max_bag_size:
+                out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+
+        else: out_bag = bag
+
+        return out_bag, label, (wsi_name, batch_coords, patient)
+
+if __name__ == '__main__':
+    
+    from pathlib import Path
+    import os
+    import time
+    # from fast_tensor_dl import FastTensorDataLoader
+    # from custom_resnet50 import resnet50_baseline
+    
+    
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128uM_annotated'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    # label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_decathlon_PAS_HE_Jones_norm_rest.json'
+    output_dir = f'{data_root}/debug/augments'
+    os.makedirs(output_dir, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = FeatureBagLoader(data_root, label_path=label_path, mode='train', cache=True, n_classes=n_classes)
+
+    test_dataset = FeatureBagLoader(data_root, label_path=label_path, mode='test', cache=False, n_classes=n_classes)
+
+    # print(dataset.get_labels(0))
+    a = int(len(dataset) * 0.8)
+    b = len(dataset) - a
+    train_data, valid_data = random_split(dataset, [a, b])
+
+    train_dl = DataLoader(train_data, batch_size=1, sampler=ImbalancedDatasetSampler(train_data), num_workers=5)
+    valid_dl = DataLoader(valid_data, batch_size=1, num_workers=5)
+    test_dl = DataLoader(test_dataset)
+
+    print('train_dl: ', len(train_dl))
+    print('valid_dl: ', len(valid_dl))
+    print('test_dl: ', len(test_dl))
+
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    scaler = torch.cuda.amp.GradScaler()
+
+    # model_ft = resnet50_baseline(pretrained=True)
+    # for param in model_ft.parameters():
+    #     param.requires_grad = False
+    # model_ft.to(device)
+    # model = TransMIL(n_classes=n_classes).to(device)
+    
+
+    # print(dataset.get_labels(np.arange(len(dataset))))
+
+    c = 0
+    label_count = [0] * n_classes
+    epochs = 1
+    # print(len(dl))
+    # start = time.time()
+    for i in range(epochs):
+        start = time.time()
+        for item in tqdm(test_dl): 
+
+            # if c >= 10:
+            #     break
+            bag, label, (name, batch_coords, patient) = item
+            # print(bag.shape)
+            print(bag.shape, label)
+            # print(len(batch_names))
+            # print(label)
+            # print(batch_coords)
+            # print(name)
+            bag = bag.float().to(device)
+            # print(bag.shape)
+            # label = label.to(device)
+            # with torch.cuda.amp.autocast():
+            #     output = model(bag)
+            # c += 1
+        end = time.time()
+        print('Bag Time: ', end-start)
+
+    
\ No newline at end of file
diff --git a/code/datasets/feature_extractor.py b/code/datasets/feature_extractor.py
index 0057c8effb518dd68d00f33035081578d9615166..0f2f64d33cc5992b9b318b4f5988d388fe4c1461 100644
--- a/code/datasets/feature_extractor.py
+++ b/code/datasets/feature_extractor.py
@@ -11,6 +11,15 @@ import torchvision.transforms as transforms
 import torch.nn.functional as F
 import re
 from imgaug import augmenters as iaa
+import argparse
+
+def make_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--augment', default=False, action='store_true')
+    parser.add_argument('--cohort', default='RU', type=str)
+    
+    args = parser.parse_args()
+    return args
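+
+# usage (illustrative): python feature_extractor.py --cohort DEEPGRAFT_RU --augment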
 
 def chunker(seq, size):
     return (seq[pos:pos + size] for pos in range(0, len(seq), size))
@@ -21,11 +30,12 @@ def get_coords(batch_names): #ToDO: Change function for precise coords
     for tile_name in batch_names: 
         # print(tile_name)
         pos = re.findall(r'\((.*?)\)', tile_name)
-        x, y = pos[-1].split('_')
+        x, y = pos[-1].replace('-', '_').split('_')
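+        # e.g. a tile stem ending in "(1024-2048)" or "(1024_2048)" (assumed naming)
+        # parses to the coordinate pair (1024, 2048)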
         coords.append((int(x),int(y)))
     return coords
 
-def augment(img):
+def iaa_augment(img):
 
     sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
     sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
@@ -53,13 +63,22 @@ def augment(img):
 
 if __name__ == '__main__':
 
+    torch.set_num_threads(8)
+    torch.manual_seed(2022)
+
+    args = make_parse()
+    
+    augment = args.augment
+    cohorts = [args.cohort]
+    print('Augment Data: ', augment)
+    print('Cohort: ', cohorts)
 
     home = Path.cwd().parts[1]
     
-    data_root = Path(f'/{home}/ylan/data/DeepGraft/224_128um_v2')
+    data_root = Path(f'/{home}/ylan/data/DeepGraft/224_128uM_annotated')
     # output_path = Path(f'/{home}/ylan/wsi_tools/debug/zarr')
-    cohorts = ['DEEPGRAFT_RA', 'Leuven'] #, 
-    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU'] #, 
+    # cohorts = ['RU', 'RA'] #, 
+    # cohorts = ['Aachen_Biopsy_Slides'] #, 
     # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, 
     compressor = Blosc(cname='blosclz', clevel=3)
 
@@ -77,28 +96,34 @@ if __name__ == '__main__':
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     scaler = torch.cuda.amp.GradScaler()
     n_classes = 2
-    out_features = 1024
-    model_ft = ResNet.resnet50(num_classes=n_classes, mlp=False, two_branch=False, normlinear=True)
+    # out_features = 1024
+    model_ft = ResNet.resnet50(num_classes=1024, mlp=False, two_branch=False, normlinear=True)
+    
+    model_ft.fc = nn.Identity()
+    # print(model_ft)
+    # model_ft.fc = nn.Linear(2048, out_features)
     home = Path.cwd().parts[1]
-    model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
-    for param in model_ft.parameters():
-        param.requires_grad = False
-    for m in model_ft.modules():
-        if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d):
-            m.eval()
-            m.weight.requires_grad = False
-            m.bias.requires_grad = False
-    model_ft.fc = nn.Linear(2048, out_features)
+    model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=True)
+    # for param in model_ft.parameters():
+    #     param.requires_grad = False
+    # for m in model_ft.modules():
+    #     if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d):
+    #         m.eval()
+    #         m.weight.requires_grad = False
+    #         m.bias.requires_grad = False
+    # model_ft.fc = nn.Linear(2048, out_features)
     model_ft.eval()
     model_ft.to(device)
 
     batch_size = 100
 
+
     for f in data_root.iterdir():
         
         if f.stem in cohorts:
-            print(f)
-            fe_path = f / 'FEATURES_RETCCL'
+            # fe_path = Path(test_output_path) / 'FEATURES_RETCCL'
+            fe_path = f / 'FEATURES_RETCCL_2048'
+
             fe_path.mkdir(parents=True, exist_ok=True)
             
             # num_files = len(list((f / 'BLOCKS').iterdir()))
@@ -107,21 +132,27 @@ if __name__ == '__main__':
                 if Path(slide).is_dir(): 
                     if slide.suffix != '.zarr':
                         slide_list.append(slide)
+            if augment:
+                tqdm_len = len(slide_list)*5
+            else: tqdm_len = len(slide_list)
 
-            with tqdm(total=len(slide_list)) as pbar:
+            
+            with tqdm(total=tqdm_len) as pbar:
                 for slide in slide_list:
-                    # print('slide: ', slide)
 
-                    # run every slide 5 times for augments
-                    for n in range(5):
+                    
 
 
-                        output_path = fe_path / Path(str(slide.stem) + f'_aug{n}.zarr')
-                        if output_path.is_dir():
-                                pbar.update(1)
-                                print(output_path, ' skipped.')
-                                continue
-                        # else:
+                    # print('slide: ', slide)
+
+                    # run every slide 5 times for augments
+                    if not augment:
+                        output_path = fe_path / Path(str(slide.stem) + '.zarr')
+                        # if output_path.is_dir():
+                        #     pbar.update(1)
+                        #     print(output_path, ' skipped.')
+                        #     continue
+                            # else:
                         output_array = []
                         output_batch_names = []
                         for tile_path_batch in chunker(list(slide.iterdir()), batch_size):
@@ -130,9 +161,6 @@ if __name__ == '__main__':
                             for t in tile_path_batch:
                                 # for n in range(5):
                                 img = np.asarray(Image.open(str(t))).astype(np.uint8) #.astype(np.uint8)
-
-                                img = augment(img)
-
                                 img = val_transforms(img.copy()).to(device)
                                 batch_array.append(img)
 
@@ -142,8 +170,8 @@ if __name__ == '__main__':
                                 continue
                             else:
                                 batch_array = torch.stack(batch_array) 
-                                # with torch.cuda.amp.autocast():
-                                model_output = model_ft(batch_array).detach()
+                                with torch.cuda.amp.autocast():
+                                    model_output = model_ft(batch_array).detach()
                                 output_array.append(model_output)
                                 output_batch_names += batch_names 
                         if len(output_array) == 0:
@@ -152,35 +180,60 @@ if __name__ == '__main__':
                         else:
                             output_array = torch.cat(output_array, dim=0).cpu().numpy()
                             output_batch_coords = get_coords(output_batch_names)
-                            # print(output_batch_coords)
-                            # z = zarr.group()
-                            # data = z.create_group('data')
-                            # tile_names = z.create_group('tile_names')
-                            # d1 = data.create_dataset('bag', shape=output_array.shape, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')
-                            # d2 = tile_names.create_dataset(output_batch_coords, shape=[len(output_batch_coords), 2], chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')
-
-                            # z['data'] = zarr.array(output_array, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='float') # 1792 = 224*8
-                            # z['tile_names'] = zarr.array(output_batch_coords, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='int32') # 1792 = 224*8
-                            # z.save
-                            # print(z['data'])
-                            # print(z['data'][:])
-                            # print(z['tile_names'][:])
-                            # zarr.save(output_path, z)
                             zarr.save_group(output_path, data=output_array, coords=output_batch_coords)
 
+                            # test eval mode!
                             # z_test = zarr.open(output_path, 'r')
-                            # print(z_test['data'][:])
                             # # print(z_test.tree())
                             
                             # if np.all(output_array== z_test['data'][:]):
-                            #     print('data true')
+                            #     print('data same')
+                            # else: print(slide)
                             # if np.all(z['tile_names'][:] == z_test['tile_names'][:]):
                             #     print('tile_names true')
                             #     print(output_path ' ')
                             # print(np.all(z[:] == z_test[:]))
-
-                        # np.save(f'{str(slide)}.npy', slide_np)
                             pbar.update(1)
+                    else:
+                        for n in range(5):
+                            # if n != 5:
+                            output_path = fe_path / Path(str(slide.stem) + f'_aug{n}.zarr')
+                            if output_path.is_dir():
+                                pbar.update(1)
+                                # print(output_path, ' skipped.')
+                                continue
+                            # else:
+                            output_array = []
+                            output_batch_names = []
+                            for tile_path_batch in chunker(list(slide.iterdir()), batch_size):
+                                batch_array = []
+                                batch_names = []
+                                for t in tile_path_batch:
+                                    # for n in range(5):
+                                    img = np.asarray(Image.open(str(t))).astype(np.uint8) #.astype(np.uint8)
+                                    img = iaa_augment(img)
+                                    img = val_transforms(img.copy()).to(device)
+                                    batch_array.append(img)
+
+                                    tile_name = t.stem
+                                    batch_names.append(tile_name)
+                                if len(batch_array) == 0:
+                                    continue
+                                else:
+                                    batch_array = torch.stack(batch_array) 
+                                    with torch.cuda.amp.autocast():
+                                        model_output = model_ft(batch_array).detach()
+                                    output_array.append(model_output)
+                                    output_batch_names += batch_names 
+                            if len(output_array) == 0:
+                                pbar.update(1)
+                                continue
+                            else:
+                                output_array = torch.cat(output_array, dim=0).cpu().numpy()
+                                output_batch_coords = get_coords(output_batch_names)
+                                zarr.save_group(output_path, data=output_array, coords=output_batch_coords)
+
+                                pbar.update(1)
             
 
                   
\ No newline at end of file
diff --git a/code/datasets/feature_extractor_2.py b/code/datasets/feature_extractor_2.py
deleted file mode 100644
index 0057c8effb518dd68d00f33035081578d9615166..0000000000000000000000000000000000000000
--- a/code/datasets/feature_extractor_2.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import numpy as np
-from pathlib import Path
-from PIL import Image
-from tqdm import tqdm
-import zarr
-from numcodecs import Blosc
-import torch
-import torch.nn as nn
-import ResNet as ResNet 
-import torchvision.transforms as transforms
-import torch.nn.functional as F
-import re
-from imgaug import augmenters as iaa
-
-def chunker(seq, size):
-    return (seq[pos:pos + size] for pos in range(0, len(seq), size))
-
-def get_coords(batch_names): #ToDO: Change function for precise coords
-    coords = []
-    
-    for tile_name in batch_names: 
-        # print(tile_name)
-        pos = re.findall(r'\((.*?)\)', tile_name)
-        x, y = pos[-1].split('_')
-        coords.append((int(x),int(y)))
-    return coords
-
-def augment(img):
-
-    sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
-    sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
-    sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
-    sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
-    sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
-
-    transforms = iaa.Sequential([
-        iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
-        sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
-        iaa.Fliplr(0.5, name="MyFlipLR"),
-        iaa.Flipud(0.5, name="MyFlipUD"),
-        sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
-        iaa.OneOf([
-            sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
-            sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
-            sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
-        ], name="MyOneOf")
-    ])
-    seq_img_d = transforms.to_deterministic()
-    img = seq_img_d.augment_image(img)
-
-    return img
-
-
-if __name__ == '__main__':
-
-
-    home = Path.cwd().parts[1]
-    
-    data_root = Path(f'/{home}/ylan/data/DeepGraft/224_128um_v2')
-    # output_path = Path(f'/{home}/ylan/wsi_tools/debug/zarr')
-    cohorts = ['DEEPGRAFT_RA', 'Leuven'] #, 
-    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU'] #, 
-    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, 
-    compressor = Blosc(cname='blosclz', clevel=3)
-
-    val_transforms = transforms.Compose([
-            # 
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=[0.485, 0.456, 0.406],
-                std=[0.229, 0.224, 0.225],
-            ),
-            # RangeNormalization(),
-        ])
-
-
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    scaler = torch.cuda.amp.GradScaler()
-    n_classes = 2
-    out_features = 1024
-    model_ft = ResNet.resnet50(num_classes=n_classes, mlp=False, two_branch=False, normlinear=True)
-    home = Path.cwd().parts[1]
-    model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
-    for param in model_ft.parameters():
-        param.requires_grad = False
-    for m in model_ft.modules():
-        if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d):
-            m.eval()
-            m.weight.requires_grad = False
-            m.bias.requires_grad = False
-    model_ft.fc = nn.Linear(2048, out_features)
-    model_ft.eval()
-    model_ft.to(device)
-
-    batch_size = 100
-
-    for f in data_root.iterdir():
-        
-        if f.stem in cohorts:
-            print(f)
-            fe_path = f / 'FEATURES_RETCCL'
-            fe_path.mkdir(parents=True, exist_ok=True)
-            
-            # num_files = len(list((f / 'BLOCKS').iterdir()))
-            slide_list = []
-            for slide in (f / 'BLOCKS').iterdir():
-                if Path(slide).is_dir(): 
-                    if slide.suffix != '.zarr':
-                        slide_list.append(slide)
-
-            with tqdm(total=len(slide_list)) as pbar:
-                for slide in slide_list:
-                    # print('slide: ', slide)
-
-                    # run every slide 5 times for augments
-                    for n in range(5):
-
-
-                        output_path = fe_path / Path(str(slide.stem) + f'_aug{n}.zarr')
-                        if output_path.is_dir():
-                                pbar.update(1)
-                                print(output_path, ' skipped.')
-                                continue
-                        # else:
-                        output_array = []
-                        output_batch_names = []
-                        for tile_path_batch in chunker(list(slide.iterdir()), batch_size):
-                            batch_array = []
-                            batch_names = []
-                            for t in tile_path_batch:
-                                # for n in range(5):
-                                img = np.asarray(Image.open(str(t))).astype(np.uint8) #.astype(np.uint8)
-
-                                img = augment(img)
-
-                                img = val_transforms(img.copy()).to(device)
-                                batch_array.append(img)
-
-                                tile_name = t.stem
-                                batch_names.append(tile_name)
-                            if len(batch_array) == 0:
-                                continue
-                            else:
-                                batch_array = torch.stack(batch_array) 
-                                # with torch.cuda.amp.autocast():
-                                model_output = model_ft(batch_array).detach()
-                                output_array.append(model_output)
-                                output_batch_names += batch_names 
-                        if len(output_array) == 0:
-                            pbar.update(1)
-                            continue
-                        else:
-                            output_array = torch.cat(output_array, dim=0).cpu().numpy()
-                            output_batch_coords = get_coords(output_batch_names)
-                            # print(output_batch_coords)
-                            # z = zarr.group()
-                            # data = z.create_group('data')
-                            # tile_names = z.create_group('tile_names')
-                            # d1 = data.create_dataset('bag', shape=output_array.shape, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')
-                            # d2 = tile_names.create_dataset(output_batch_coords, shape=[len(output_batch_coords), 2], chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='i4')
-
-                            # z['data'] = zarr.array(output_array, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='float') # 1792 = 224*8
-                            # z['tile_names'] = zarr.array(output_batch_coords, chunks=True, compressor = compressor, synchronizer=zarr.ThreadSynchronizer(), dtype='int32') # 1792 = 224*8
-                            # z.save
-                            # print(z['data'])
-                            # print(z['data'][:])
-                            # print(z['tile_names'][:])
-                            # zarr.save(output_path, z)
-                            zarr.save_group(output_path, data=output_array, coords=output_batch_coords)
-
-                            # z_test = zarr.open(output_path, 'r')
-                            # print(z_test['data'][:])
-                            # # print(z_test.tree())
-                            
-                            # if np.all(output_array== z_test['data'][:]):
-                            #     print('data true')
-                            # if np.all(z['tile_names'][:] == z_test['tile_names'][:]):
-                            #     print('tile_names true')
-                            #     print(output_path ' ')
-                            # print(np.all(z[:] == z_test[:]))
-
-                        # np.save(f'{str(slide)}.npy', slide_np)
-                            pbar.update(1)
-            
-
-                  
\ No newline at end of file
diff --git a/code/datasets/feature_extractor_annotated.ipynb b/code/datasets/feature_extractor_annotated.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..2b34715d99116a1dabc39af1c23856690afc884a
--- /dev/null
+++ b/code/datasets/feature_extractor_annotated.ipynb
@@ -0,0 +1,223 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from pathlib import Path\n",
+    "from PIL import Image\n",
+    "from tqdm import tqdm\n",
+    "import zarr\n",
+    "from numcodecs import Blosc\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import ResNet as ResNet \n",
+    "import torchvision.transforms as transforms\n",
+    "import torch.nn.functional as F\n",
+    "import re\n",
+    "from imgaug import augmenters as iaa\n",
+    "import argparse\n",
+    "import json"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def make_parse():\n",
+    "    parser = argparse.ArgumentParser()\n",
+    "    parser.add_argument('--augment', default=False, action='store_true')\n",
+    "    \n",
+    "    args = parser.parse_args()\n",
+    "    return args\n",
+    "\n",
+    "def chunker(seq, size):\n",
+    "    return (seq[pos:pos + size] for pos in range(0, len(seq), size))\n",
+    "\n",
+    "def get_coords(batch_names): #ToDO: Change function for precise coords\n",
+    "    coords = []\n",
+    "    \n",
+    "    for tile_name in batch_names: \n",
+    "        # print(tile_name)\n",
+    "        pos = re.findall(r'\\((.*?)\\)', tile_name)\n",
+    "        x, y = pos[-1].split('_')\n",
+    "        coords.append((int(x),int(y)))\n",
+    "    return coords\n",
+    "\n",
+    "def iaa_augment(img):\n",
+    "\n",
+    "    sometimes = lambda aug: iaa.Sometimes(0.5, aug, name=\"Random1\")\n",
+    "    sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name=\"Random2\")\n",
+    "    sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name=\"Random3\")\n",
+    "    sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name=\"Random4\")\n",
+    "    sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name=\"Random5\")\n",
+    "\n",
+    "    transforms = iaa.Sequential([\n",
+    "        iaa.AddToHueAndSaturation(value=(-30, 30), name=\"MyHSV\"), #13\n",
+    "        sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name=\"MyGamma\")),\n",
+    "        iaa.Fliplr(0.5, name=\"MyFlipLR\"),\n",
+    "        iaa.Flipud(0.5, name=\"MyFlipUD\"),\n",
+    "        sometimes(iaa.Rot90(k=1, keep_size=True, name=\"MyRot90\")),\n",
+    "        iaa.OneOf([\n",
+    "            sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name=\"MyPiece\")),\n",
+    "            sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name=\"MyElastic\")),\n",
+    "            sometimes5(iaa.Affine(scale={\"x\": (0.95, 1.05), \"y\": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name=\"MyAffine\"))\n",
+    "        ], name=\"MyOneOf\")\n",
+    "    ])\n",
+    "    seq_img_d = transforms.to_deterministic()\n",
+    "    img = seq_img_d.augment_image(img)\n",
+    "\n",
+    "    return img\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "usage: ipykernel_launcher.py [-h] [--augment]\n",
+      "ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9040 --control=9038 --hb=9037 --Session.signature_scheme=\"hmac-sha256\" --Session.key=b\"fea321c1-c51d-4123-ac4e-21d7b6f0be68\" --shell=9039 --transport=\"tcp\" --iopub=9041 --f=/home/ylan/.local/share/jupyter/runtime/kernel-v2-8466dCu9m1xIy2SG.json\n"
+     ]
+    },
+    {
+     "ename": "SystemExit",
+     "evalue": "2",
+     "output_type": "error",
+     "traceback": [
+      "An exception has occurred, use %tb to see the full traceback.\n",
+      "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
+     ]
+    }
+   ],
+   "source": [
+    "torch.set_num_threads(8)\n",
+    "torch.manual_seed(2022)\n",
+    "\n",
+    "args = make_parse()\n",
+    "\n",
+    "augment=args.augment\n",
+    "print('Augment Data: ', augment)\n",
+    "\n",
+    "home = Path.cwd().parts[1]\n",
+    "data_root = Path(f'/{home}/ylan/data/DeepGraft/tissue_detection/224_128uM/images')\n",
+    "slide_patient_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'\n",
+    "cohort_stain_path = f'/{home}/ylan/DeepGraft/training_tables/cohort_stain_dict.json'\n",
+    "with open(slide_patient_path, 'r') as f:\n",
+    "    slide_patient_dict = json.load(f)\n",
+    "with open(cohort_stain_path, 'r') as f:\n",
+    "    cohort_stain_dict = json.load(f)\n",
+    "# output_path = Path(f'/{home}/ylan/wsi_tools/debug/zarr')\n",
+    "# cohorts = ['DEEPGRAFT_RU'] #, \n",
+    "# cohorts = ['Aachen_Biopsy_Slides'] #, \n",
+    "# cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, \n",
+    "compressor = Blosc(cname='blosclz', clevel=3)\n",
+    "\n",
+    "val_transforms = transforms.Compose([\n",
+    "        # \n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize(\n",
+    "            mean=[0.485, 0.456, 0.406],\n",
+    "            std=[0.229, 0.224, 0.225],\n",
+    "        ),\n",
+    "        # RangeNormalization(),\n",
+    "    ])\n",
+    "\n",
+    "\n",
+    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+    "scaler = torch.cuda.amp.GradScaler()\n",
+    "n_classes = 2\n",
+    "# out_features = 1024\n",
+    "model_ft = ResNet.resnet50(num_classes=128, mlp=False, two_branch=False, normlinear=True)\n",
+    "\n",
+    "model_ft.fc = nn.Identity()\n",
+    "# print(model_ft)\n",
+    "# model_ft.fc = nn.Linear(2048, out_features)\n",
+    "home = Path.cwd().parts[1]\n",
+    "model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=True)\n",
+    "# for param in model_ft.parameters():\n",
+    "#     param.requires_grad = False\n",
+    "# for m in model_ft.modules():\n",
+    "#     if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d):\n",
+    "#         m.eval()\n",
+    "#         m.weight.requires_grad = False\n",
+    "#         m.bias.requires_grad = False\n",
+    "# model_ft.fc = nn.Linear(2048, out_features)\n",
+    "model_ft.eval()\n",
+    "model_ft.to(device)\n",
+    "\n",
+    "batch_size = 100"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31mRunning cells with 'Python 3.8.10 64-bit' requires ipykernel package.\n",
+      "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
+      "\u001b[1;31mCommand: '/usr/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
+     ]
+    }
+   ],
+   "source": [
+    "patient_cohort_dict = {}\n",
+    "for cohort in cohort_stain_dict.keys():\n",
+    "    cohort_patient_list = list(cohort_stain_dict[cohort].keys())\n",
+    "    for patient in cohort_patient_list:\n",
+    "        patient_cohort_dict[patient] = cohort"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for f in data_root.iterdir():\n",
+    "    slide_name = f.stem.split('_', 1)[0]\n",
+    "    patient = slide_patient_dict[slide_name]"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.9.13 ('torch')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.8"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "7b7fb95db5714bbf59d6a04f6057e8fa5746fef9d16f5c42f2fdbc713170171a"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/code/datasets/feature_extractor_annotated.py b/code/datasets/feature_extractor_annotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e6a78c2e37a51a88158e0003903b4b0366abcd
--- /dev/null
+++ b/code/datasets/feature_extractor_annotated.py
@@ -0,0 +1,267 @@
+import numpy as np
+from pathlib import Path
+from PIL import Image
+from tqdm import tqdm
+import zarr
+from numcodecs import Blosc
+import torch
+import torch.nn as nn
+import ResNet as ResNet 
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+import re
+from imgaug import augmenters as iaa
+import argparse
+import json
+
+def make_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--augment', default=False, action='store_true')
+    
+    args = parser.parse_args()
+    return args
+
+def chunker(seq, size):
+    return (seq[pos:pos + size] for pos in range(0, len(seq), size))
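+# e.g. list(chunker([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]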
+
+def get_coords(batch_names): #ToDO: Change function for precise coords
+    coords = []
+    
+    for tile_name in batch_names: 
+        # print(tile_name)
+        pos = re.findall(r'\((.*?)\)', tile_name)
+        x, y = pos[-1].split('-')
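+        # note: unlike feature_extractor.py, tile names here are assumed to use a
+        # '-' separator inside the parentheses, e.g. "(1024-2048)"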
+        coords.append((int(x),int(y)))
+    return coords
+
+def iaa_augment(img):
+
+    sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+    sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+    sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+    sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+    sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+    transforms = iaa.Sequential([
+        iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
+        sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+        iaa.Fliplr(0.5, name="MyFlipLR"),
+        iaa.Flipud(0.5, name="MyFlipUD"),
+        sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+        iaa.OneOf([
+            sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+            sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+            sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+        ], name="MyOneOf")
+    ])
+    seq_img_d = transforms.to_deterministic()
+    img = seq_img_d.augment_image(img)
+
+    return img
+
+
+torch.set_num_threads(8)
+torch.manual_seed(2022)
+
+args = make_parse()
+
+augment = args.augment
+print('Augment Data: ', augment)
+
+home = Path.cwd().parts[1]
+data_root = Path(f'/{home}/ylan/data/DeepGraft/tissue_detection/224_128uM/training/images')
+output_dataset_path = Path(f'/{home}/ylan/data/DeepGraft/224_128uM_annotated/')
+slide_patient_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
+cohort_stain_path = f'/{home}/ylan/DeepGraft/training_tables/cohort_stain_dict.json'
+
+with open(slide_patient_path, 'r') as f:
+    slide_patient_dict = json.load(f)
+with open(cohort_stain_path, 'r') as f:
+    cohort_stain_dict = json.load(f)
+# output_path = Path(f'/{home}/ylan/wsi_tools/debug/zarr')
+# cohorts = ['DEEPGRAFT_RU'] #, 
+# cohorts = ['Aachen_Biopsy_Slides'] #, 
+# cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, 
+compressor = Blosc(cname='blosclz', clevel=3)
+
+val_transforms = transforms.Compose([
+        # 
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ),
+        # RangeNormalization(),
+    ])
+
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+scaler = torch.cuda.amp.GradScaler()
+n_classes = 2
+# out_features = 1024
+model_ft = ResNet.resnet50(num_classes=1024, mlp=False, two_branch=False, normlinear=True)
+
+model_ft.fc = nn.Identity()
+# print(model_ft)
+# model_ft.fc = nn.Linear(2048, out_features)
+home = Path.cwd().parts[1]
+model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=True)
+# for param in model_ft.parameters():
+#     param.requires_grad = False
+# for m in model_ft.modules():
+#     if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d):
+#         m.eval()
+#         m.weight.requires_grad = False
+#         m.bias.requires_grad = False
+# model_ft.fc = nn.Linear(2048, out_features)
+model_ft.eval()
+model_ft.to(device)
+
+batch_size = 100
+
+
+skipped_slides = []
+
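+# invert cohort -> patients into a flat patient -> cohort lookup, e.g.
+# {'CohortA': {'p1': ...}} (names illustrative) becomes {'p1': 'CohortA'}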
+patient_cohort_dict = {}
+for cohort in cohort_stain_dict.keys():
+    cohort_patient_list = list(cohort_stain_dict[cohort].keys())
+    for patient in cohort_patient_list:
+        patient_cohort_dict[patient] = cohort
+
+print('patient_cohort_dict completed.')
+
+cohort_slide_tiles_dict = {}
+for tile in data_root.iterdir():
+    
+    slide_name = tile.stem.rsplit('_', 1)[0]
+    if slide_name in slide_patient_dict.keys():
+        patient = slide_patient_dict[slide_name]
+        cohort = patient_cohort_dict[patient]
+
+        if cohort not in cohort_slide_tiles_dict.keys():
+            cohort_slide_tiles_dict[cohort] = {}
+        if slide_name not in cohort_slide_tiles_dict[cohort].keys():
+            cohort_slide_tiles_dict[cohort][slide_name] = []
+        cohort_slide_tiles_dict[cohort][slide_name].append(tile)
+
+    else: 
+        skipped_slides.append(slide_name)
+
+print('cohort_slide_tiles_dict complete.')
+for cohort in cohort_slide_tiles_dict.keys():
+    print(cohort)
+    print(len(list(cohort_slide_tiles_dict[cohort].keys())))
+
+for c in cohort_slide_tiles_dict.keys():
+
+    fe_path = output_dataset_path / c / 'FEATURES_RETCCL_2048'
+    fe_path.mkdir(parents=True, exist_ok=True)
+
+    slide_list = list(cohort_slide_tiles_dict[c].keys())
+
+    if augment:
+        tqdm_len = len(slide_list)*5
+    else: tqdm_len = len(slide_list)
+
+    with tqdm(total=tqdm_len) as pbar:
+        for slide in slide_list:
+
+            # run every slide 5 times for augments
+            if not augment:
+                output_path = fe_path / Path(slide + '.zarr')
+                if output_path.is_dir():
+                    pbar.update(1)
+                    print(output_path, ' skipped.')
+                    continue
+                    # else:
+                output_array = []
+                output_batch_names = []
+                # for tile_path_batch in chunker(list(slide.iterdir()), batch_size):
+                for tile_path_batch in chunker(cohort_slide_tiles_dict[c][slide], batch_size):
+                    batch_array = []
+                    batch_names = []
+                    for t in tile_path_batch:
+                        # for n in range(5):
+                        img = np.asarray(Image.open(str(t))).astype(np.uint8) #.astype(np.uint8)
+                        img = val_transforms(img.copy()).to(device)
+                        batch_array.append(img)
+
+                        # print(t)
+                        # print(t.stem)
+                        tile_name = t.stem
+                        batch_names.append(tile_name)
+                    if len(batch_array) == 0:
+                        continue
+                    else:
+                        batch_array = torch.stack(batch_array) 
+                        with torch.cuda.amp.autocast():
+                            model_output = model_ft(batch_array).detach()
+                        output_array.append(model_output)
+                        output_batch_names += batch_names 
+                if len(output_array) == 0:
+                    pbar.update(1)
+                    continue
+                else:
+                    output_array = torch.cat(output_array, dim=0).cpu().numpy()
+                    output_batch_coords = get_coords(output_batch_names)
+                    zarr.save_group(output_path, data=output_array, coords=output_batch_coords)
+
+                    # sanity check: re-open the written zarr group and verify that the
+                    # features round-trip intact
+                    z_test = zarr.open(output_path, 'r')
+                    if not np.array_equal(output_array, z_test['data'][:]):
+                        print('feature mismatch after write:', slide)
+                    pbar.update(1)
+            else:
+                # one augmented feature file per pass, matching the *5 progress-bar total
+                for n in range(5):
+                    output_path = fe_path / Path(slide + f'_aug{n}.zarr')
+                    if output_path.is_dir():
+                        pbar.update(1)
+                        # print(output_path, ' skipped.')
+                        continue
+                    # else:
+                    output_array = []
+                    output_batch_names = []
+                    # for tile_path_batch in chunker(list(slide.iterdir()), batch_size):
+                    for tile_path_batch in chunker(cohort_slide_tiles_dict[c][slide], batch_size):
+                        batch_array = []
+                        batch_names = []
+                        for t in tile_path_batch:
+                            # for n in range(5):
+                            img = np.asarray(Image.open(str(t))).astype(np.uint8) #.astype(np.uint8)
+                            img = iaa_augment(img)
+                            img = val_transforms(img.copy()).to(device)
+                            batch_array.append(img)
+
+                            tile_name = t.stem
+                            batch_names.append(tile_name)
+                        if len(batch_array) == 0:
+                            continue
+                        else:
+                            batch_array = torch.stack(batch_array) 
+                            with torch.cuda.amp.autocast():
+                                model_output = model_ft(batch_array).detach()
+                            output_array.append(model_output)
+                            output_batch_names += batch_names 
+                    if len(output_array) == 0:
+                        pbar.update(1)
+                        continue
+                    else:
+                        output_array = torch.cat(output_array, dim=0).cpu().numpy()
+                        output_batch_coords = get_coords(output_batch_names)
+                        zarr.save_group(output_path, data=output_array, coords=output_batch_coords)
+
+                        pbar.update(1)
+
+
+print('skipped slides:')
+print(skipped_slides)
diff --git a/code/datasets/feature_file_checker.py b/code/datasets/feature_file_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..e594e4b3eea4ada09ff0db1b88ad8bd990af7f0b
--- /dev/null
+++ b/code/datasets/feature_file_checker.py
@@ -0,0 +1,82 @@
+import numpy as np
+from pathlib import Path
+from PIL import Image
+from tqdm import tqdm
+import zarr
+from numcodecs import Blosc
+import torch
+import torch.nn as nn
+import ResNet as ResNet 
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+import re
+from imgaug import augmenters as iaa
+
+def chunker(seq, size):
+    return (seq[pos:pos + size] for pos in range(0, len(seq), size))
+
+def get_coords(batch_names): #ToDO: Change function for precise coords
+    coords = []
+    
+    for tile_name in batch_names: 
+        # print(tile_name)
+        pos = re.findall(r'\((.*?)\)', tile_name)
+        x, y = pos[-1].split('_')
+        coords.append((int(x),int(y)))
+    return coords
+
+
+
+if __name__ == '__main__':
+
+
+    home = Path.cwd().parts[1]
+    
+    data_root = Path(f'/{home}/ylan/data/DeepGraft/224_128um_v2')
+    # output_path = Path(f'/{home}/ylan/wsi_tools/debug/zarr')
+    # cohorts = ['Leuven'] #, 
+    cohorts = ['DEEPGRAFT_RU'] #, 
+    # cohorts = ['Aachen_Biopsy_Slides'] #, 
+    # cohorts = ['DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, 
+    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA'] #, 
+    # cohorts = ['Aachen_Biopsy_Slides', 'DEEPGRAFT_RU', 'DEEPGRAFT_RA', 'Leuven'] #, 
+    
+    for f in data_root.iterdir():
+        
+        if f.stem in cohorts:
+            print(f)
+            fe_path = f / 'FEATURES_RETCCL'
+            fe_path.mkdir(parents=True, exist_ok=True)
+            slide_list = []
+            counter = 0
+            for slide in (f / 'BLOCKS').iterdir():
+                if Path(slide).is_dir(): 
+                    if slide.suffix != '.zarr':
+                        slide_list.append(slide)
+
+            print(len(slide_list))
+
+            with tqdm(total=len(slide_list)) as pbar:
+                for slide in slide_list:
+                    # print('slide: ', slide)
+
+                    # check the 5 augmented feature files plus the base file per slide
+                    for n in range(6):
+                        if n != 5:
+                            output_path = fe_path / Path(str(slide.stem) + f'_aug{n}.zarr')
+                        else:
+                            output_path = fe_path / Path(str(slide.stem) + '.zarr')
+                        if output_path.is_dir():
+                            # print(output_path, ' skipped.')
+                            pbar.update(1)
+                            continue
+                        else:
+                            counter += 1
+                            print(output_path)
+                            pbar.update(1)
+            print(counter)
+
+                  
\ No newline at end of file
diff --git a/code/datasets/monai_loader.py b/code/datasets/monai_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bbe8033e7e15c82f1ffcafce7feddc7cc1ce770
--- /dev/null
+++ b/code/datasets/monai_loader.py
@@ -0,0 +1,179 @@
+import numpy as np
+import collections.abc
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from monai.config import KeysCollection
+from monai.data import Dataset, load_decathlon_datalist, PersistentDataset
+from monai.data.wsi_reader import WSIReader, CuCIMWSIReader
+# from monai.data.image_reader import CuCIMWSIReader
+from monai.networks.nets import milmodel
+from monai.transforms import (
+    Compose,
+    GridPatchd,
+    LoadImaged,
+    LoadImage,
+    MapTransform,
+    RandFlipd,
+    RandGridPatchd,
+    RandRotate90d,
+    ScaleIntensityRanged,
+    SplitDimd,
+    ToTensord,
+)
+from sklearn.metrics import cohen_kappa_score
+from torch.cuda.amp import GradScaler, autocast
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import  SummaryWriter
+import json
+from pathlib import Path
+import time
+
+class LabelEncodeIntegerGraded(MapTransform):
+    """
+    Convert an integer label to encoded array representation of length num_classes,
+    with 1 filled in up to label index, and 0 otherwise. For example for num_classes=5,
+    embedding of 2 -> (1,1,0,0,0)
+    Args:
+        num_classes: the number of classes to convert to encoded format.
+        keys: keys of the corresponding items to be transformed. Defaults to ``'label'``.
+        allow_missing_keys: don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        keys: KeysCollection = "label",
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
+        self.num_classes = num_classes
+
+    def __call__(self, data):
+
+        d = dict(data)
+        for key in self.keys:
+            label = int(d[key])
+
+            lz = np.zeros(self.num_classes, dtype=np.float32)
+            lz[:label] = 1.0
+            # alternative one-liner: lz = (np.arange(self.num_classes) < int(label)).astype(np.float32)
+            d[key] = lz
+
+        return d
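+
+# sketch: LabelEncodeIntegerGraded(num_classes=5)({"label": 2})["label"]
+# yields array([1., 1., 0., 0., 0.], dtype=float32)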
+
+def list_data_collate(batch: collections.abc.Sequence):
+    # print(f"{i} = {item['image'].shape=} >> {item['image'].keys=}")
+    for i, item in enumerate(batch):
+        data = item[0]
+        data["image"] = torch.stack([ix["image"] for ix in item], dim=0)
+        # data["patch_location"] = torch.stack([ix["patch_location"] for ix in item], dim=0)
+        batch[i] = data
+    return default_collate(batch)
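+
+# SplitDimd(..., list_output=True) makes each dataset item a list of per-patch dicts;
+# the collate above re-stacks them into one (num_patches, C, H, W) image tensor per
+# slide before handing off to default_collate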
+
+
+
+
+
+if __name__ == '__main__':
+
+    num_classes = 2
+    batch_size=1
+    tile_size = 224
+    tile_count = 1000
+    home = Path.cwd().parts[1]
+    data_root = f'/{home}/ylan/DeepGraft/'
+    # labels = [0]
+    # data_root = f'/{home}/public/DeepGraft/Aachen_Biopsy_Slides_Extended'
+    data = {"training": [{
+        "image": 'Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs', 
+        "label": 0
+        }, {
+        "image": 'Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs', 
+        "label": 0
+        }, {
+        "image": 'Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs', 
+        "label": 0
+        }],
+        "validation": [{
+        "image": 'Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs', 
+        "label": 0
+        }]
+    }
+    with open('monai_test.json', 'w') as jf:
+        json.dump(data, jf)
+    json_data_path = f'/{home}/ylan/DeepGraft/training_tables/dg_decathlon_PAS_HE_Jones_norm_rest.json'
+
+    training_list = load_decathlon_datalist(
+        data_list_file_path=json_data_path,
+        data_list_key="training",
+        base_dir=data_root,
+    )
+
+    train_transform = Compose(
+        [
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=0, image_only=True, num_workers=8),
+            LabelEncodeIntegerGraded(keys=["label"], num_classes=num_classes),
+            RandGridPatchd(
+                keys=["image"],
+                patch_size=(tile_size, tile_size),
+                threshold=0.999 * 3 * 255 * tile_size * tile_size,
+                num_patches=None,
+                sort_fn="min",
+                pad_mode=None,
+                constant_values=255,
+            ),
+            SplitDimd(keys=["image"], dim=0, keepdim=False, list_output=True),
+            RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
+            RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
+            RandRotate90d(keys=["image"], prob=0.5),
+            ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ToTensord(keys=["image", "label"]),
+        ]
+    )
+    train_data_list = data['training']
+    # dataset_train = Dataset(data=training_list)
+    dataset_train = Dataset(data=training_list, transform=train_transform)
+    # persistent_dataset = PersistentDataset(data=training_list, transform=train_transform, cache_dir='/home/ylan/workspace/test')
+    
+
+    train_loader = torch.utils.data.DataLoader(
+        dataset_train,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=1,
+        pin_memory=True,
+        sampler=None,
+        collate_fn=list_data_collate,
+    )
+
+    print(len(train_loader))
+    start = time.time()
+    count = 0
+
+    # train_transform = LoadImage(reader=WSIReader, backend='openslide', level=3)
+    # filename = '/home/ylan/DeepGraft/DEEPGRAFT_RU/T19-01474_I1_HE 10_959004.ndpi'
+    # X = train_transform(filename)
+    # print(X)
+    # img, meta = reader.read(data='/home/ylan/DeepGraft/DEEPGRAFT_RU/T19-01474_I1_HE 10_959004.ndpi')
+
+    # print(meta)
+
+    for idx, batch_data in enumerate(train_loader):
+        # print(batch_data)
+        if count > 10: 
+            break
+        data, target = batch_data["image"], batch_data["label"]
+        print(target)
+        count += 1
+    end = time.time()
+    print('Time: ', end-start)
+
+    # image_reader = WSIReader(backend='cucim')
+    # for i in training_list:
+    #     # print(i)
+    #     wsi = image_reader.read(i['image'])
+    #     img_data, meta_data = image_reader.get_data(wsi)
+    #     print(meta_data)
\ No newline at end of file
diff --git a/code/datasets/simple_jpg_dataloader.py b/code/datasets/simple_jpg_dataloader.py
index 332920a6b269ff1f4593e2f60e50ffa3109a4d12..c5e349f3b7f6426140fe8c9b026d3dc8932853bf 100644
--- a/code/datasets/simple_jpg_dataloader.py
+++ b/code/datasets/simple_jpg_dataloader.py
@@ -20,8 +20,8 @@ from imgaug import augmenters as iaa
 from torchsampler import ImbalancedDatasetSampler
 
 
-class FeatureBagLoader(data_utils.Dataset):
-    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=100, max_bag_size=1000):
+class JPGBagLoader(data_utils.Dataset):
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=100, max_bag_size=1000, cache=False):
         super().__init__()
 
         self.data_info = []
@@ -35,7 +35,7 @@ class FeatureBagLoader(data_utils.Dataset):
         self.label_path = label_path
         self.n_classes = n_classes
         self.max_bag_size = max_bag_size
-        self.min_bag_size = 120
+        self.min_bag_size = 50
         self.empty_slides = []
         self.corrupt_slides = []
         self.cache = True
@@ -222,7 +222,7 @@ class FeatureBagLoader(data_utils.Dataset):
         if self.cache:
             label = self.labels[index]
             wsi = self.features[index]
-            label = Variable(Tensor(label))
+            label = int(label)
             wsi_name = self.wsi_names[index]
             name_batch = self.name_batches[index]
             patient = self.patients[index]
@@ -231,13 +231,14 @@ class FeatureBagLoader(data_utils.Dataset):
         else:
             if self.mode=='train':
                 batch, label, (wsi_name, name_batch, patient) = self.get_data(self.files[index])
-                label = Variable(Tensor(label))
+                # label = Variable(Tensor(label))
+
                 # wsi = Variable(Tensor(wsi_batch))
                 out_batch = []
                 seq_img_d = self.train_transforms.to_deterministic()
                 for img in batch: 
                     img = img.numpy().astype(np.uint8)
-                    img = seq_img_d.augment_image(img)
+                    # img = seq_img_d.augment_image(img)
                     img = self.val_transforms(img.copy())
                     out_batch.append(img)
                 out_batch = torch.stack(out_batch)
@@ -278,7 +279,7 @@ if __name__ == '__main__':
 
     n_classes = 2
 
-    dataset = FeatureBagLoader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    dataset = JPGBagLoader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
 
     # print(dataset.get_labels(0))
     a = int(len(dataset)* 0.8)
@@ -311,7 +312,7 @@ if __name__ == '__main__':
         bag, label, (name, batch_names, patient) = item
         # print(bag.shape)
         # print(len(batch_names))
-        
+        print(label)
         bag = bag.squeeze(0).float().to(device)
         label = label.to(device)
         with torch.cuda.amp.autocast():
diff --git a/code/datasets/test_normalization.ipynb b/code/datasets/test_normalization.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..c9e0a6bb756bb2f89dacf2369ba4d5f7a96b8937
--- /dev/null
+++ b/code/datasets/test_normalization.ipynb
@@ -0,0 +1,195 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from simple_jpg_dataloader import JPGBagLoader\n",
+    "import torch\n",
+    "from torch.utils.data import random_split, DataLoader\n",
+    "from pathlib import Path\n",
+    "import numpy as np\n",
+    "import random\n",
+    "from torchvision.transforms import transforms\n",
+    "import matplotlib.pyplot as plt\n",
+    "from PIL import Image\n",
+    "import cv2\n",
+    "import json"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "home = Path.cwd().parts[1]\n",
+    "label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'\n",
+    "data_root = f'/{home}/ylan/data/DeepGraft/224_128uM_annotated'\n",
+    "n_classes = 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "val_transforms = transforms.Compose([\n",
+    "    # \n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize(\n",
+    "        mean=[0.485, 0.456, 0.406],\n",
+    "        std=[0.229, 0.224, 0.225],\n",
+    "    ),\n",
+    "    # RangeNormalization(),\n",
+    "])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def visualize(cohort):\n",
+    "\n",
+    "    cohort_path = Path(data_root) / cohort / 'BLOCKS'\n",
+    "    print(cohort_path)\n",
+    "    cohort_slides = list(Path(cohort_path).iterdir())\n",
+    "    random_idx = random.sample(range(0, len(cohort_slides)), 10)\n",
+    "    random_slides = [cohort_slides[i] for i in random_idx]\n",
+    "    print(random_slides)\n",
+    "\n",
+    "    fig = plt.figure(figsize=(100,100))\n",
+    "    columns = 10\n",
+    "    rows = 10\n",
+    "\n",
+    "    for i, slide in enumerate(random_slides):\n",
+    "        tile_list = list(slide.iterdir())\n",
+    "        if len(tile_list) < 10:\n",
+    "            # continue\n",
+    "            tile_list = list(cohort_slides[random.randint(0,len(cohort_slides))].iterdir())\n",
+    "        random_idx = random.sample(range(0, len(tile_list)), 10)\n",
+    "        for j, tile_path in enumerate([tile_list[i] for i in random_idx]):\n",
+    "            img = np.asarray(Image.open(tile_path)).astype(np.uint8)\n",
+    "            img = img.astype(np.uint8)\n",
+    "            img = val_transforms(img.copy())\n",
+    "            img = ((img-img.min())/(img.max()-img.min()))*255\n",
+    "            img = img.numpy().astype(np.uint8).transpose(1,2,0)\n",
+    "            img = Image.fromarray(img)\n",
+    "            img = img.convert('RGB')\n",
+    "            # print((i+1)*rows+j)\n",
+    "            fig.add_subplot(rows, columns, (i)*rows+(j+1))\n",
+    "            # fig.add_subplot(rows, columns, (i+1)*rows+(j+1))\n",
+    "            plt.imshow(img)\n",
+    "    plt.show\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def hexencode(rgb):\n",
+    "    r=rgb[0]\n",
+    "    g=rgb[1]\n",
+    "    b=rgb[2]\n",
+    "    return '#%02x%02x%02x' % (r,g,b)\n",
+    "\n",
+    "def normalize(slides):\n",
+    "\n",
+    "    # cohort_path = Path(data_root) / 'debug' / slide\n",
+    "    # # print(cohort_path)\n",
+    "    # cohort_slides = list(Path(cohort_path).iterdir())\n",
+    "    # random_idx = random.sample(range(0, len(cohort_slides)), 5)\n",
+    "    # random_slides = [cohort_slides[i] for i in random/_idx]\n",
+    "    # print(random_slides)\n",
+    "\n",
+    "    fig = plt.figure(figsize=(100,100))\n",
+    "    columns = 10\n",
+    "    rows = 10\n",
+    "\n",
+    "    for i, slide in enumerate(slides):\n",
+    "        slide_path = Path(data_root) / 'debug' / slide\n",
+    "        tile_list = list(slide_path.iterdir())\n",
+    "        if len(tile_list) < 10:\n",
+    "            # continue\n",
+    "            tile_list = list(cohort_slides[random.randint(0,len(cohort_slides))].iterdir())\n",
+    "        random_idx = random.sample(range(0, len(tile_list)), 5)\n",
+    "        for j, tile_path in enumerate([tile_list[i] for i in random_idx]):\n",
+    "            # print(tile_path)\n",
+    "            img = np.asarray(Image.open(tile_path)).astype(np.uint8)\n",
+    "            img = img.astype(np.uint8)\n",
+    "            img = val_transforms(img.copy())\n",
+    "            img = ((img-img.min())/(img.max()-img.min()))*255\n",
+    "            img_np = img.numpy().astype(np.uint8).transpose(1,2,0)\n",
+    "            img = Image.fromarray(img_np)\n",
+    "            img = img.convert('RGB')\n",
+    "            # print((i+1)*rows+j)\n",
+    "            # fig.add_subplot(rows, columns, (i*2)*rows+(j+1))\n",
+    "            # # fig.add_subplot(rows, columns, (i+1)*rows+(j+1))\n",
+    "            # plt.imshow(img)\n",
+    "\n",
+    "            color = ('b','g','r')\n",
+    "            fig.add_subplot(rows, columns, (i*2)*rows+(j+1))\n",
+    "            for i,col in enumerate(color):\n",
+    "                histr = cv2.calcHist([img_np],[i],None,[256],[0,256])\n",
+    "                plt.plot(histr,color = col)\n",
+    "                plt.xlim([0,256])\n",
+    "            plt.show\n",
+    "            # plt.imshow(img)\n",
+    "\n",
+    "        \n",
+    "    plt.show\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "json_path = f'/{home}/ylan/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'\n",
+    "with open(json_path, 'r') as jf:\n",
+    "    split_dict = json.read(jf)\n",
+    "\n",
+    "print(split_dict)\n",
+    "\n",
+    "slides = ['DEEPGRAFT_RA/RA0002_PASD_jkers_PASD_20180829_142406', 'DEEPGRAFT_RU/RU0001_PASD_jke_PASD_20200129_122805_BIG', 'Aachen_Biopsy_Slides/Aachen_KiBiDatabase_KiBiAcALSZ690_01_004_PAS']\n",
+    "\n",
+    "# normalize(slides)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.8.8 ('torch')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.13"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "7b7fb95db5714bbf59d6a04f6057e8fa5746fef9d16f5c42f2fdbc713170171a"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
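The notebook normalizes tiles with ImageNet statistics and then min-max rescales the tensor back to uint8 so `plt.imshow` can render it. A self-contained sketch of that round trip (the random tile stands in for a real patch):

```python
import numpy as np
from torchvision.transforms import transforms

val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def to_displayable(img_np: np.ndarray) -> np.ndarray:
    """Normalize a HxWx3 uint8 tile, then rescale to uint8 for plotting."""
    t = val_transforms(img_np.copy())              # CxHxW float tensor
    t = (t - t.min()) / (t.max() - t.min()) * 255  # min-max back to [0, 255]
    return t.numpy().astype(np.uint8).transpose(1, 2, 0)

tile = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
disp = to_displayable(tile)  # ready for plt.imshow(disp)
```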
diff --git a/code/datasets/zarr_feature_dataloader.py b/code/datasets/zarr_feature_dataloader.py
index b244b87c83a128d9b40b604fef8df647690d482b..9387e13952ab180c173a556cd640909454651a85 100644
--- a/code/datasets/zarr_feature_dataloader.py
+++ b/code/datasets/zarr_feature_dataloader.py
@@ -22,11 +22,11 @@ from PIL import Image
 
 
 class ZarrFeatureBagLoader(data.Dataset):
-    def __init__(self, file_path, label_path, mode, n_classes, cache=False, data_cache_size=100, max_bag_size=1000):
+    def __init__(self, file_path, label_path, mode, n_classes, cache=False, data_cache_size=50, max_bag_size=1000):
         super().__init__()
 
         self.data_info = []
-        self.data_cache = {}
+        self.data_cache = []
         self.slideLabelDict = {}
         self.files = []
         self.data_cache_size = data_cache_size
@@ -39,17 +39,23 @@ class ZarrFeatureBagLoader(data.Dataset):
         self.min_bag_size = 120
         self.empty_slides = []
         self.corrupt_slides = []
-        self.cache = True
-        
+        self.cache = cache
+        self.drop_rate = 0.1
+        print('mode: ', self.mode)
         # read labels and slide_path from csv
         with open(self.label_path, 'r') as f:
-            temp_slide_label_dict = json.load(f)[mode]
+            temp_slide_label_dict = json.load(f)[self.mode]
             # print(len(temp_slide_label_dict))
             for (x, y) in temp_slide_label_dict:
                 x = Path(x).stem
                 # x_complete_path = Path(self.file_path)/Path(x)
                 for cohort in Path(self.file_path).iterdir():
-                    x_complete_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL' / (str(x) + '.zarr')
+                    if self.mode == 'test':
+                        x_complete_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL_GAN' / (str(x) + '.zarr')
+                    else:
+                        x_complete_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL' / (str(x) + '.zarr')
+                    # print(x_complete_path)
                     if x_complete_path.is_dir():
                         # if len(list(x_complete_path.iterdir())) > self.min_bag_size:
                         # # print(x_complete_path)
@@ -66,26 +72,65 @@ class ZarrFeatureBagLoader(data.Dataset):
         self.wsi_names = []
         self.name_batches = []
         self.patients = []
-        if self.cache:
-            for t in tqdm(self.files):
-                # zarr_t = str(t) + '.zarr'
-                batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
-
-                self.labels.append(label)
-                self.feature_bags.append(batch)
-                self.wsi_names.append(wsi_name)
-                self.name_batches.append(name_batch)
-                self.patients.append(patient)
-
-    def get_data(self, file_path, drop_rate=0.1):
-        
-        batch_names=[] #add function for name_batch read out
+        for t in tqdm(self.files):
+            self._add_data_infos(t, cache=cache)
+
+
+        print('data_cache_size: ', self.data_cache_size)
+        print('data_info: ', len(self.data_info))
+        # if self.cache:
+        #     print('Loading data into cache.')
+        #     for t in tqdm(self.files):
+        #         # zarr_t = str(t) + '.zarr'
+        #         batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
+
+        #         self.labels.append(label)
+        #         self.feature_bags.append(batch)
+        #         self.wsi_names.append(wsi_name)
+        #         self.name_batches.append(name_batch)
+        #         self.patients.append(patient)
+        # else: 
+            
 
+    def _add_data_infos(self, file_path, cache):
+
+        # if cache:
         wsi_name = Path(file_path).stem
-        if wsi_name in self.slideLabelDict:
-            label = self.slideLabelDict[wsi_name]
-            
-            patient = self.slide_patient_dict[wsi_name]
+        # if wsi_name in self.slideLabelDict:
+        label = self.slideLabelDict[wsi_name]
+        patient = self.slide_patient_dict[wsi_name]
+        idx = -1
+        self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'patient': patient, 'cache_idx': idx})
+
+    def get_data(self, i):
+
+        fp = self.data_info[i]['data_path']
+        idx = self.data_info[i]['cache_idx']
+        if idx == -1:
+            # not cached yet: load from disk and register its cache index
+            self._load_data(fp)
+
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        patient = self.data_info[i]['patient']
+
+        return self.data_cache[cache_idx], label, name, patient
+        # return self.data_cache[fp][cache_idx], label, name, patient
+        
+
+
+    def _load_data(self, file_path):
+        
+
+        # batch_names=[] #add function for name_batch read out
+        # wsi_name = Path(file_path).stem
+        # if wsi_name in self.slideLabelDict:
+        #     label = self.slideLabelDict[wsi_name]
+        #     patient = self.slide_patient_dict[wsi_name]
+
         z = zarr.open(file_path, 'r')
         np_bag = np.array(z['data'][:])
         # np_bag = np.array(zarr.open(file_path, 'r')).astype(np.uint8)
@@ -96,15 +141,45 @@ class ZarrFeatureBagLoader(data.Dataset):
         bag_size = wsi_bag.shape[0]
         
         # random drop 
-        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-self.drop_rate))]
         wsi_bag = wsi_bag[bag_idxs, :]
         batch_coords = batch_coords[bag_idxs]
+
+        idx = self._add_to_cache((wsi_bag, batch_coords), file_path)
+        file_idx = next(i for i, v in enumerate(self.data_info) if v['data_path'] == file_path)
+        # print('file_idx: ', file_idx)
+        # print('idx: ', idx)
+        self.data_info[file_idx]['cache_idx'] = idx
         # print(wsi_bag.shape)
         # name_samples = [batch_names[i] for i in bag_idxs]
-        return wsi_bag, label, (wsi_name, batch_coords, patient)
+        # return wsi_bag, label, (wsi_name, batch_coords, patient)
+        
+        if len(self.data_cache) > self.data_cache_size:
+            # evict the oldest entry (not the one just added) and keep the
+            # remaining cache indices consistent
+            self.data_cache.pop(0)
+            for di in self.data_info:
+                if di['cache_idx'] == 0:
+                    di['cache_idx'] = -1
+                elif di['cache_idx'] > 0:
+                    di['cache_idx'] -= 1
+        
+
+
+    def _add_to_cache(self, data, data_path):
+
+
+        # if data_path not in self.data_cache:
+        #     self.data_cache[data_path] = [data]
+        # else:
+        #     self.data_cache[data_path].append(data)
+        self.data_cache.append(data)
+        # print(len(self.data_cache))
+        # return len(self.data_cache)
+        return len(self.data_cache) - 1
+
     
     def get_labels(self, indices):
-        return [self.labels[i] for i in indices]
+        # return [self.labels[i] for i in indices]
+        return [self.data_info[i]['label'] for i in indices]
 
 
     def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
@@ -136,36 +211,17 @@ class ZarrFeatureBagLoader(data.Dataset):
         return bag_samples, name_samples
 
     def __len__(self):
-        return len(self.files)
+        # return len(self.files)
+        return len(self.data_info)
 
     def __getitem__(self, index):
 
-        if self.cache:
-            label = self.labels[index]
-            wsi = self.feature_bags[index]
-            # label = Variable(Tensor(label))
-            label = torch.as_tensor(label)
-            label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
-            wsi_name = self.wsi_names[index]
-            name_batch = self.name_batches[index]
-            patient = self.patients[index]
-
-            #random dropout
-            #shuffle
-
-            # feats = Variable(Tensor(feats))
-            return wsi, label, (wsi_name, name_batch, patient)
-        else:
-            t = self.files[index]
-            batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
-
-                # self.labels.append(label)
-                # self.feature_bags.append(batch)
-                # self.wsi_names.append(wsi_name)
-                # self.name_batches.append(name_batch)
-                # self.patients.append(patient)
-
-            return batch, label, (wsi_name, name_batch, patient)
+        (wsi, batch_coords), label, wsi_name, patient = self.get_data(index)
+
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+
+        return wsi, label, (wsi_name, batch_coords, patient)
 
 if __name__ == '__main__':
     
@@ -182,14 +238,14 @@ if __name__ == '__main__':
     data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
     # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
-    # label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    # label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
     output_dir = f'/{data_root}/debug/augments'
     os.makedirs(output_dir, exist_ok=True)
 
     n_classes = 2
 
-    dataset = ZarrFeatureBagLoader(data_root, label_path=label_path, mode='train', cache=True, n_classes=n_classes)
+    dataset = ZarrFeatureBagLoader(data_root, label_path=label_path, mode='train', cache=False, data_cache_size=3000, n_classes=n_classes)
 
     # print(dataset.get_labels(0))
     a = int(len(dataset)* 0.8)
@@ -200,7 +256,7 @@ if __name__ == '__main__':
     # b = int(len(dataset) - a)
     # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
     # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
-    dl = DataLoader(train_data, batch_size=1, num_workers=8, sampler=ImbalancedDatasetSampler(train_data), pin_memory=True)
+    dl = DataLoader(train_data, batch_size=1, num_workers=8)#, pin_memory=True , sampler=ImbalancedDatasetSampler(train_data)
     # print(len(dl))
     # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -210,29 +266,35 @@ if __name__ == '__main__':
     # for param in model_ft.parameters():
     #     param.requires_grad = False
     # model_ft.to(device)
-    model = TransMIL(n_classes=n_classes).to(device)
+    # model = TransMIL(n_classes=n_classes).to(device)
     
     c = 0
     label_count = [0] *n_classes
     epochs = 1
-    # print(len(dl))
+    print(len(dl))
     # start = time.time()
+
+    count = 0
     for i in range(epochs):
         start = time.time()
         for item in tqdm(dl): 
-
             # if c >= 10:
             #     break
             bag, label, (name, batch_names, patient) = item
             # print(bag.shape)
             # print(len(batch_names))
-            print(label)
-            print(batch_names)
+            # print(label)
+            # print(batch_names)
             bag = bag.float().to(device)
+            # print(bag)
+            # print(name)
+            # bag = bag.float().to(device)
             # print(bag.shape)
             # label = label.to(device)
-            with torch.cuda.amp.autocast():
-                output = model(bag)
-            # c += 1
+            # with torch.cuda.amp.autocast():
+            #     output = model(bag)
+            count += 1
+            
         end = time.time()
-        print('Bag Time: ', end-start)
\ No newline at end of file
+        print('Bag Time: ', end-start)
+        print(count)
\ No newline at end of file
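The refactor above replaces eager caching with an index-based scheme: `data_info` holds per-slide metadata plus a `cache_idx` into a flat list, where -1 means "not loaded". A minimal sketch of that pattern, assuming FIFO eviction (the bookkeeping that keeps surviving indices valid is the subtle part):

```python
class TinyBagCache:
    """Index-based feature cache: load on first access, evict oldest first."""

    def __init__(self, loader, cache_size=50):
        self.loader = loader        # callable: path -> feature bag
        self.cache_size = cache_size
        self.data_info = []         # [{'data_path': ..., 'cache_idx': -1}, ...]
        self.cache = []

    def add(self, path):
        self.data_info.append({'data_path': path, 'cache_idx': -1})

    def get(self, i):
        info = self.data_info[i]
        if info['cache_idx'] == -1:                  # not cached: load now
            self.cache.append(self.loader(info['data_path']))
            info['cache_idx'] = len(self.cache) - 1
            if len(self.cache) > self.cache_size:
                self.cache.pop(0)                    # evict the oldest entry
                for di in self.data_info:            # keep indices consistent
                    if di['cache_idx'] == 0:
                        di['cache_idx'] = -1
                    elif di['cache_idx'] > 0:
                        di['cache_idx'] -= 1
        return self.cache[info['cache_idx']]
```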
diff --git a/code/datasets/zarr_feature_dataloader_simple.py b/code/datasets/zarr_feature_dataloader_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9313c9447efdf83a5ff89707d2dcacfab2c6021
--- /dev/null
+++ b/code/datasets/zarr_feature_dataloader_simple.py
@@ -0,0 +1,255 @@
+import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from torchsampler import ImbalancedDatasetSampler
+from torchvision import datasets, transforms
+import pandas as pd
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+import zarr
+import json
+import cv2
+from PIL import Image
+# from models import TransMIL
+
+
+
+class ZarrFeatureBagLoader(data.Dataset):
+    def __init__(self, file_path, label_path, mode, n_classes, cache=False, data_cache_size=5000, max_bag_size=1000):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.drop_rate = 0.1
+        # self.min_bag_size = 120
+        self.empty_slides = []
+        self.corrupt_slides = []
+        self.cache = cache
+        
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            # print(len(temp_slide_label_dict))
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem
+                # x_complete_path = Path(self.file_path)/Path(x)
+                for cohort in Path(self.file_path).iterdir():
+                    # x_complete_path = Path(self.file_path) / cohort / 'FEATURES_RETCCL' / (str(x) + '.zarr')
+                    if self.mode == 'test': #set to test if using GAN output
+                        x_path_list = [Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x) + '.zarr')]
+                    else:
+                        x_path_list = [Path(self.file_path) / cohort / 'FEATURES_RETCCL_2048' / (str(x) + '.zarr')]
+                        for i in range(5):
+                            x_path_list.append(Path(self.file_path) / cohort / 'FEATURES_RETCCL' / (str(x) + f'_aug{i}.zarr'))
+                    # print(x_complete_path)
+                    for x_path in x_path_list:
+                        if x_path.is_dir():
+                            # if len(list(x_complete_path.iterdir())) > self.min_bag_size:
+                            # # print(x_complete_path)
+                            self.slideLabelDict[x] = y
+                            self.files.append(x_path)
+        
+        # print(self.files)
+        home = Path.cwd().parts[1]
+        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
+        with open(self.slide_patient_dict_path, 'r') as f:
+            self.slide_patient_dict = json.load(f)
+
+        self.feature_bags = []
+        self.labels = []
+        self.wsi_names = []
+        self.name_batches = []
+        self.patients = []
+        if self.cache:
+            for t in tqdm(self.files):
+                # zarr_t = str(t) + '.zarr'
+                batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
+
+                self.labels.append(label)
+                self.feature_bags.append(batch)
+                self.wsi_names.append(wsi_name)
+                self.name_batches.append(name_batch)
+                self.patients.append(patient)
+
+    def get_data(self, file_path):
+        
+        batch_names=[] #add function for name_batch read out
+
+        wsi_name = Path(file_path).stem
+        if wsi_name.split('_')[-1][:3] == 'aug':
+            wsi_name = '_'.join(wsi_name.split('_')[:-1])
+        # if wsi_name in self.slideLabelDict:
+        label = self.slideLabelDict[wsi_name]
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        patient = self.slide_patient_dict[wsi_name]
+        z = zarr.open(file_path, 'r')
+        np_bag = np.array(z['data'][:])
+        # np_bag = np.array(zarr.open(file_path, 'r')).astype(np.uint8)
+        wsi_bag = torch.from_numpy(np_bag)
+        batch_coords = torch.from_numpy(np.array(z['coords'][:]))
+
+        # print(wsi_bag.shape)
+        bag_size = wsi_bag.shape[0]
+        
+        # random drop 
+        
+        bag_idxs = torch.randperm(bag_size)[:int(self.max_bag_size*(1-self.drop_rate))]
+        wsi_bag = wsi_bag[bag_idxs, :]
+        batch_coords = batch_coords[bag_idxs]
+        # print(wsi_bag.shape)
+        # name_samples = [batch_names[i] for i in bag_idxs]
+        return wsi_bag, label, (wsi_name, batch_coords, patient)
+    
+    def get_labels(self, indices):
+        return [self.labels[i] for i in indices]
+
+
+    def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
+
+        # duplicate bag instances until the bag reaches bag_size
+
+        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+        bag_samples = bag[bag_idxs]
+        name_samples = [names[i] for i in bag_idxs]
+        # bag_sample_names = [bag_names[i] for i in bag_idxs]
+        # q, r  = divmod(bag_size, bag_samples.shape[0])
+        # if q > 0:
+        #     bag_samples = torch.cat([bag_samples]*q, 0)
+
+        # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+        # zero-pad if we don't have enough samples
+        # zero_padded = torch.cat((bag_samples,
+        #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+        return bag_samples, name_samples, min(bag_size, len(bag))
+
+    def data_dropout(self, bag, batch_names, drop_rate):
+        bag_size = bag.shape[0]
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_samples = bag[bag_idxs]
+        name_samples = [batch_names[i] for i in bag_idxs]
+
+        return bag_samples, name_samples
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+
+        if self.cache:
+            label = self.labels[index]
+            wsi = self.feature_bags[index]
+            # label = Variable(Tensor(label))
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+            wsi_name = self.wsi_names[index]
+            name_batch = self.name_batches[index]
+            patient = self.patients[index]
+
+            #random dropout
+            #shuffle
+
+            # feats = Variable(Tensor(feats))
+            return wsi, label, (wsi_name, name_batch, patient)
+        else:
+            t = self.files[index]
+            batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+                # self.labels.append(label)
+                # self.feature_bags.append(batch)
+                # self.wsi_names.append(wsi_name)
+                # self.name_batches.append(name_batch)
+                # self.patients.append(patient)
+
+            return batch, label, (wsi_name, name_batch, patient)
+
+if __name__ == '__main__':
+    
+    from pathlib import Path
+    import os
+    import time
+    # from fast_tensor_dl import FastTensorDataLoader
+    # from custom_resnet50 import resnet50_baseline
+    
+    
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    # label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
+    output_dir = f'/{data_root}/debug/augments'
+    os.makedirs(output_dir, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = ZarrFeatureBagLoader(data_root, label_path=label_path, mode='train', cache=True, n_classes=n_classes)
+
+    # print(dataset.get_labels(0))
+    a = int(len(dataset)* 0.8)
+    b = int(len(dataset) - a)
+    train_data, valid_data = random_split(dataset, [a, b])
+    # print(dataset.dataset)
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
+    dl = DataLoader(train_data, batch_size=1, num_workers=8, pin_memory=True)
+    # print(len(dl))
+    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    scaler = torch.cuda.amp.GradScaler()
+
+    # model_ft = resnet50_baseline(pretrained=True)
+    # for param in model_ft.parameters():
+    #     param.requires_grad = False
+    # model_ft.to(device)
+    # model = TransMIL(n_classes=n_classes).to(device)
+    
+    c = 0
+    label_count = [0] *n_classes
+    epochs = 1
+    # print(len(dl))
+    # start = time.time()
+    for i in range(epochs):
+        start = time.time()
+        for item in tqdm(dl): 
+
+            # if c >= 10:
+            #     break
+            bag, label, (name, batch_coords, patient) = item
+            # print(bag.shape)
+            # print(len(batch_names))
+            # print(label)
+            # print(batch_coords)
+            print(name)
+            bag = bag.float().to(device)
+            # print(bag.shape)
+            # label = label.to(device)
+            # with torch.cuda.amp.autocast():
+            #     output = model(bag)
+            # c += 1
+        end = time.time()
+        print('Bag Time: ', end-start)
\ No newline at end of file
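The "random drop" in `get_data` above permutes the instance indices and keeps a (1 - drop_rate) fraction, applying the same selection to features and coordinates so they stay aligned. A standalone sketch (shapes are illustrative):

```python
import torch

def drop_instances(bag: torch.Tensor, coords: torch.Tensor, drop_rate: float = 0.1):
    """bag: [N, F] feature bag; coords: [N, 2] tile coordinates."""
    n = bag.shape[0]
    keep = torch.randperm(n)[: int(n * (1 - drop_rate))]  # random subset
    return bag[keep], coords[keep]

bag, coords = torch.randn(1000, 2048), torch.randint(0, 224, (1000, 2))
bag_d, coords_d = drop_instances(bag, coords)  # ~900 aligned instances remain
```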
diff --git a/code/models/AttMIL.py b/code/models/AttMIL.py
index 048fb0fc782aa856c1e7c6edcc850e0e3cb39552..89ff5d5aae5461e72181f416cdba0b7f13f93125 100644
--- a/code/models/AttMIL.py
+++ b/code/models/AttMIL.py
@@ -13,7 +13,7 @@ from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 
 
 class AttMIL(nn.Module): #gated attention
-    def __init__(self, n_classes, features=512):
+    def __init__(self, n_classes, features=1024):
         super(AttMIL, self).__init__()
         self.L = features
         self.D = 128
diff --git a/code/models/TransMIL.py b/code/models/TransMIL.py
index ddc126286835639ded83fd68054401c500237061..01a0fa7a9ec3ebfa0c8655eb6f68bafa79fc6f01 100755
--- a/code/models/TransMIL.py
+++ b/code/models/TransMIL.py
@@ -32,11 +32,12 @@ class TransLayer(nn.Module):
         )
 
     def forward(self, x):
-        out, attn = self.attn(self.norm(x), return_attn=True)
+        out= self.attn(self.norm(x))
+        # out, attn = self.attn(self.norm(x))
         x = x + out
         # x = x + self.attn(self.norm(x))
 
-        return x, attn
+        return x
 
 
 class PPEG(nn.Module):
@@ -59,16 +60,19 @@ class PPEG(nn.Module):
 class TransMIL(nn.Module):
     def __init__(self, n_classes):
         super(TransMIL, self).__init__()
-        in_features = 1024
+        in_features = 2048
+        inter_features = 1024
         out_features = 512
-        self.pos_layer = PPEG(dim=out_features)
-        self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
-        # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
         if apex_available: 
             norm_layer = apex.normalization.FusedLayerNorm
-            
         else:
             norm_layer = nn.LayerNorm
+
+        self.pos_layer = PPEG(dim=out_features)
+        self._fc1 = nn.Sequential(nn.Linear(in_features, inter_features), nn.GELU(), nn.Dropout(p=0.5), norm_layer(inter_features)) 
+        self._fc1_2 = nn.Sequential(nn.Linear(inter_features, out_features), nn.GELU())
+        # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
+        
         self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
         self.n_classes = n_classes
         self.layer1 = TransLayer(norm_layer=norm_layer, dim=out_features)
@@ -90,8 +94,10 @@ class TransMIL(nn.Module):
     def forward(self, x): #, **kwargs
 
         # x = self.model_ft(x).unsqueeze(0)
-        h = x.float() #[B, n, 1024]
+        h = x.squeeze(0).float() #[n, 2048]
         h = self._fc1(h) #[n, 1024]
+        # h = self.drop(h)
+        h = self._fc1_2(h) #[n, 512]
         
         # print('Feature Representation: ', h.shape)
         #---->duplicate pad
@@ -110,7 +116,8 @@ class TransMIL(nn.Module):
 
 
         #---->Translayer x1
-        h, attn1 = self.layer1(h) #[B, N, 512]
+        h = self.layer1(h) #[B, N, 512]
+        # h, attn1 = self.layer1(h) #[B, N, 512]
 
         # print('After first TransLayer: ', h.shape)
 
@@ -119,7 +126,8 @@ class TransMIL(nn.Module):
         # print('After PPEG: ', h.shape)
         
         #---->Translayer x2
-        h, attn2 = self.layer2(h) #[B, N, 512]
+        h = self.layer2(h) #[B, N, 512]
+        # h, attn2 = self.layer2(h) #[B, N, 512]
 
         # print('After second TransLayer: ', h.shape) #[1, 1025, 512] 1025 = cls_token + 1024
         #---->cls_token
@@ -128,7 +136,11 @@ class TransMIL(nn.Module):
 
         #---->predict
         logits = self._fc2(h) #[B, n_classes]
-        return logits, attn2
+        # Y_hat = torch.argmax(logits, dim=1)
+        # Y_prob = F.softmax(logits, dim = 1)
+        # results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
+        return logits
+        # return logits, attn2
 
 if __name__ == "__main__":
     data = torch.randn((1, 6000, 1024)).cuda()
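The TransMIL change widens the input projection for 2048-d RetCCL features and splits it into two stages. A standalone sketch of the new feature path, using plain `nn.LayerNorm` (the apex fused variant is a drop-in when available); note the smoke test above still builds 1024-d inputs and would need the same update:

```python
import torch
import torch.nn as nn

# stage 1: 2048 -> 1024 with GELU, dropout and LayerNorm; stage 2: 1024 -> 512
fc1 = nn.Sequential(nn.Linear(2048, 1024), nn.GELU(),
                    nn.Dropout(p=0.5), nn.LayerNorm(1024))
fc1_2 = nn.Sequential(nn.Linear(1024, 512), nn.GELU())

x = torch.randn(1, 6000, 2048)        # [B, n, in_features]
h = fc1_2(fc1(x.squeeze(0).float()))  # forward() squeezes the batch dim first
print(h.shape)                        # torch.Size([6000, 512])
```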
diff --git a/code/models/__pycache__/AttMIL.cpython-39.pyc b/code/models/__pycache__/AttMIL.cpython-39.pyc
index 550f6876773393094b93f1d14f352eb024d7917a..5d379a46200a876d652766c8d970eef8af36a4a0 100644
Binary files a/code/models/__pycache__/AttMIL.cpython-39.pyc and b/code/models/__pycache__/AttMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/TransMIL.cpython-39.pyc b/code/models/__pycache__/TransMIL.cpython-39.pyc
index 547c0a2e1594a0ef49a2c0482c066af1ca368fc9..21d7707719795a1335b1187bffed7fd0328ba3fb 100644
Binary files a/code/models/__pycache__/TransMIL.cpython-39.pyc and b/code/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/TransformerMIL.cpython-39.pyc b/code/models/__pycache__/TransformerMIL.cpython-39.pyc
index 7ed17970f205f815c2b3f9dc1b6529a296ec9548..f2c5bda0715cd2d8601f4a33e2187a553a3f322a 100644
Binary files a/code/models/__pycache__/TransformerMIL.cpython-39.pyc and b/code/models/__pycache__/TransformerMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/model_interface.cpython-39.pyc b/code/models/__pycache__/model_interface.cpython-39.pyc
index 176ab832bfacbc7d53356fade7a947b9b78e4a7e..0466ab737a0a0da0cfa6b04ea69f2aef82561a6d 100644
Binary files a/code/models/__pycache__/model_interface.cpython-39.pyc and b/code/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/code/models/model_interface.py b/code/models/model_interface.py
index 9f4f282f67a84c5482a3f30fcb40b6d95bdcdd7d..0186c069c9f62c945c40647906823ed424deb06f 100755
--- a/code/models/model_interface.py
+++ b/code/models/model_interface.py
@@ -10,6 +10,7 @@ from pathlib import Path
 from matplotlib import pyplot as plt
 import cv2
 from PIL import Image
+from pytorch_pretrained_vit import ViT
 
 #---->
 from MyOptimizer import create_optimizer
@@ -28,6 +29,13 @@ import torch.nn.functional as F
 import torchmetrics
 from torchmetrics.functional import stat_scores
 from torch import optim as optim
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+from monai.config import KeysCollection
+from monai.data import Dataset, load_decathlon_datalist
+from monai.data.wsi_reader import WSIReader
+from monai.metrics import Cumulative, CumulativeAverage
+from monai.networks.nets import milmodel
 
 # from sklearn.metrics import roc_curve, auc, roc_curve_score
 
@@ -69,11 +77,19 @@ class ModelInterface(pl.LightningModule):
         super(ModelInterface, self).__init__()
         self.save_hyperparameters()
         self.n_classes = model.n_classes
-        self.load_model()
-        self.loss = create_loss(loss, model.n_classes)
-        # self.loss = AUCM_MultiLabel(num_classes = model.n_classes, device=self.device)
+        
+        if model.name == 'AttTrans':
+            self.model = milmodel.MILModel(num_classes=self.n_classes, pretrained=True, mil_mode='att_trans', backbone_num_features=1024)
+        else:
+            self.load_model()
+        # self.loss = create_loss(loss, model.n_classes)
+        # self.loss = 
+        if self.n_classes>2:
+            self.aucm_loss = AUCM_MultiLabel(num_classes = model.n_classes, device=self.device)
+        else:
+            self.aucm_loss = AUCMLoss()
         # self.asl = AsymmetricLossSingleLabel()
-        # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
+        self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
+
         # self.loss = 
         # print(self.model)
         self.model_name = model.name
@@ -99,7 +115,7 @@ class ModelInterface(pl.LightningModule):
         # print(self.experiment)
         #---->Metrics
         if self.n_classes > 2: 
-            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes)
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average='macro')
             
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
                                                                            average='weighted'),
@@ -131,7 +147,9 @@ class ModelInterface(pl.LightningModule):
         # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10)
         self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)                                                                    
         self.valid_metrics = metrics.clone(prefix = 'val_')
+        self.valid_patient_metrics = metrics.clone(prefix = 'val_patient_')
         self.test_metrics = metrics.clone(prefix = 'test_')
+        self.test_patient_metrics = metrics.clone(prefix = 'test_patient_')
 
         #--->random
         self.shuffle = kargs['data'].data_shuffle
@@ -146,12 +164,16 @@ class ModelInterface(pl.LightningModule):
             self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
             self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
         elif self.backbone == 'resnet18':
-            self.model_ft = models.resnet18(pretrained=True)
+            self.model_ft = models.resnet18(weights='IMAGENET1K_V1')
             # modules = list(resnet18.children())[:-1]
+            # frozen_layers = 8
+            # for child in self.model_ft.children():
+
             for param in self.model_ft.parameters():
                 param.requires_grad = False
             self.model_ft.fc = nn.Linear(512, self.out_features)
 
+
             # res18 = nn.Sequential(
             #     *modules,
             # )
@@ -235,22 +257,28 @@ class ModelInterface(pl.LightningModule):
 
     def forward(self, x):
         # print(x.shape)
+        if self.model_name == 'AttTrans':
+            return self.model(x)
         if self.model_ft:
+            x = x.squeeze(0)
             feats = self.model_ft(x).unsqueeze(0)
         else: 
             feats = x.unsqueeze(0)
+        
         return self.model(feats)
         # return self.model(x)
 
     def step(self, input):
 
-        input = input.squeeze(0).float()
-        logits, _ = self(input.contiguous()) 
+        input = input.float()
+        # logits, _ = self(input.contiguous()) 
+        logits = self(input.contiguous())
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
 
-        
 
-        Y_hat = torch.argmax(logits, dim=1)
-        Y_prob = F.softmax(logits, dim=1)
+        # Y_hat = torch.argmax(logits, dim=1)
+        # Y_prob = F.softmax(logits, dim=1)
 
         return logits, Y_prob, Y_hat
 
@@ -264,12 +292,19 @@ class ModelInterface(pl.LightningModule):
         # bag_idxs = torch.randperm(input.squeeze(0).shape[0])[:bag_size]
         # input = input.squeeze(0)[bag_idxs].unsqueeze(0)
 
-        label = label.float()
+        # label = label.float()
         
         logits, Y_prob, Y_hat = self.step(input) 
 
         #---->loss
         loss = self.loss(logits, label)
+
+        one_hot_label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        # aucm_loss = self.aucm_loss(torch.sigmoid(logits), one_hot_label)
+        # total_loss = torch.mean(loss + aucm_loss)
+        Y = int(label)
+        # print(logits, label)
+        # loss = cross_entropy_torch(logits.squeeze(0), label)
         # loss = self.asl(logits, label.squeeze())
 
         #---->acc log
@@ -278,11 +313,14 @@ class ModelInterface(pl.LightningModule):
         # if self.n_classes == 2:
         #     Y = int(label[0][1])
         # else: 
-        Y = torch.argmax(label)
+        # Y = torch.argmax(label)
+        
             # Y = int(label[0])
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (int(Y_hat) == Y)
-        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        # self.log('total_loss', total_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        # self.log('aucm_loss', aucm_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        self.log('lsce_loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
 
         # if self.current_epoch % 10 == 0:
 
@@ -298,7 +336,7 @@ class ModelInterface(pl.LightningModule):
         #     self.loggers[0].experiment.add_image(f'{self.current_epoch}/input', grid)
 
 
-        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': Y} 
+        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
 
     def training_epoch_end(self, training_step_outputs):
         # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
@@ -324,50 +362,65 @@ class ModelInterface(pl.LightningModule):
         if self.current_epoch % 10 == 0:
             self.log_confusion_matrix(max_probs, target, stage='train')
 
-        self.log('Train/auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        self.log('Train/auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
     def validation_step(self, batch, batch_idx):
 
-        input, label, _ = batch
-        label = label.float()
+        input, label, (wsi_name, batch_names, patient) = batch
+        # label = label.float()
         
         logits, Y_prob, Y_hat = self.step(input) 
 
         #---->acc log
         # Y = int(label[0][1])
-        Y = torch.argmax(label)
+        # Y = torch.argmax(label)
+        loss = self.loss(logits, label)
+        # loss = self.loss(logits, label)
+        # print(loss)
+        Y = int(label)
 
         # print(Y_hat)
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (int(Y_hat) == Y)
+        
         # self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': wsi_name, 'patient': patient, 'loss':loss}
 
 
     def validation_epoch_end(self, val_step_outputs):
+
+        # print(val_step_outputs)
+        # print(torch.cat([x['Y_prob'] for x in val_step_outputs], dim=0))
+        # print(torch.stack([x['Y_prob'] for x in val_step_outputs]))
+        
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        target = torch.stack([x['label'] for x in val_step_outputs])
+        target = torch.stack([x['label'] for x in val_step_outputs], dim=0).int()
+        slide_names = [x['name'] for x in val_step_outputs]
+        patients = [x['patient'] for x in val_step_outputs]
+
+        loss = torch.stack([x['loss'] for x in val_step_outputs])
+        # loss = torch.cat([x['loss'] for x in val_step_outputs])
+        # print(loss.shape)
         
-        self.log_dict(self.valid_metrics(logits, target),
+
+        # self.log('val_loss', cross_entropy_torch(logits.squeeze(), target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        self.log('val_loss', loss.mean(), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        
+        # print(logits)
+        # print(target)
+        self.log_dict(self.valid_metrics(max_probs.squeeze(), target.squeeze()),
                           on_epoch = True, logger = True, sync_dist=True)
         
-        #---->
-        # logits = logits.long()
-        # target = target.squeeze().long()
-        # logits = logits.squeeze(0)
+
         if len(target.unique()) != 1:
-            self.log('val_auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            # self.log('val_patient_auc', self.AUROC(patient_score, patient_target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         else:    
             self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
-        self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
-        
-
-        precision, recall, thresholds = self.PRC(probs, target)
-
 
 
         # print(max_probs.squeeze(0).shape)
@@ -376,6 +429,62 @@ class ModelInterface(pl.LightningModule):
 
         #----> log confusion matrix
         self.log_confusion_matrix(max_probs, target, stage='val')
+
+        #----> log per patient metrics
+        complete_patient_dict = {}
+        patient_list = []            
+        patient_score = []      
+        patient_target = []
+
+        for p, s, pr, t in zip(patients, slide_names, probs, target):
+            if p not in complete_patient_dict.keys():
+                complete_patient_dict[p] = [(s, pr)]
+                patient_target.append(t)
+            else:
+                complete_patient_dict[p].append((s, pr))
+
+       
+
+        for p in complete_patient_dict.keys():
+            score = []
+            # iterate with a local name so the slide-level probs is not shadowed
+            for (slide, prob) in complete_patient_dict[p]:
+                score.append(prob)
+
+            # if self.n_classes == 2:
+                # score =
+            score = torch.mean(torch.stack(score), dim=0) #.cpu().detach().numpy()
+            # complete_patient_dict[p]['score'] = score
+            # print(p, score)
+            # patient_list.append(p)    
+            patient_score.append(score)    
+
+        patient_score = torch.stack(patient_score)
+        # print(patient_target)
+        # print(torch.cat(patient_target))
+        # print(self.AUROC(patient_score.squeeze(), torch.cat(patient_target)))
+
+        
+        patient_target = torch.cat(patient_target)
+
+        # print(patient_score.shape)
+        # print(patient_target.shape)
+        
+        if len(patient_target.unique()) != 1:
+            self.log('val_patient_auc', self.AUROC(patient_score.squeeze(), patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:    
+            self.log('val_patient_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        
+        self.log_dict(self.valid_patient_metrics(patient_score, patient_target),
+                          on_epoch = True, logger = True, sync_dist=True)
+        
+            
+
+        # precision, recall, thresholds = self.PRC(probs, target)
+
         
 
         #---->acc log
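The block above aggregates slide-level softmax scores per patient before computing a patient-level AUROC. A hedged, self-contained sketch of that mean-pooling step (names and the toy data are illustrative):

```python
import torch

def aggregate_by_patient(patients, probs, targets):
    """patients: list[str]; probs: [N, C] slide scores; targets: [N] labels."""
    per_patient, per_target = {}, {}
    for p, pr, t in zip(patients, probs, targets):
        per_patient.setdefault(p, []).append(pr)  # group slide scores
        per_target.setdefault(p, t)               # one label per patient
    score = torch.stack([torch.stack(v).mean(dim=0) for v in per_patient.values()])
    target = torch.stack([per_target[p] for p in per_patient])
    return score, target  # [P, C] patient scores, [P] patient labels

probs = torch.softmax(torch.randn(6, 2), dim=1)
targets = torch.tensor([0, 0, 1, 1, 1, 1])
score, target = aggregate_by_patient(['a', 'a', 'b', 'b', 'c', 'c'], probs, targets)
```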
@@ -394,178 +503,117 @@ class ModelInterface(pl.LightningModule):
             self.count = self.count+1
             random.seed(self.count*50)
 
-    def test_step(self, batch, batch_idx):
 
-        torch.set_grad_enabled(True)
-        data, label, (wsi_name, batch_names) = batch
-        wsi_name = wsi_name[0]
-        label = label.float()
-        # logits, Y_prob, Y_hat = self.step(data) 
-        # print(data.shape)
-        data = data.squeeze(0).float()
-        logits, attn = self(data)
-        attn = attn.detach()
-        logits = logits.detach()
-
-        Y = torch.argmax(label)
-        Y_hat = torch.argmax(logits, dim=1)
-        Y_prob = F.softmax(logits, dim = 1)
-        
-        #----> Get GradCam maps, map each instance to attention value, assemble, overlay on original WSI 
-        if self.model_name == 'TransMIL':
-           
-            target_layers = [self.model.layer2.norm] # 32x32
-            # target_layers = [self.model_ft[0].features[-1]] # 32x32
-            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
-            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
-        else:
-            target_layers = [self.model.attention_weights]
-            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
-
-
-        data_ft = self.model_ft(data).unsqueeze(0).float()
-        instance_count = data.size(0)
-        target = [ClassifierOutputTarget(Y)]
-        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
-        grayscale_cam = torch.Tensor(grayscale_cam)[:instance_count, :]
-
-        # attention_map = grayscale_cam[:, :, 1].squeeze()
-        # attention_map = F.relu(attention_map)
-        # mask = torch.zeros((instance_count, 3, 256, 256)).to(self.device)
-        # for i, v in enumerate(attention_map):
-        #     mask[i, :, :, :] = v
-
-        # mask = self.assemble(mask, batch_names)
-        # mask = (mask - mask.min())/(mask.max()-mask.min())
-        # mask = mask.cpu().numpy()
-        # wsi = self.assemble(data, batch_names)
-        # wsi = wsi.cpu().numpy()
-
-        # def show_cam_on_image(img, mask):
-        #     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
-        #     heatmap = np.float32(heatmap) / 255
-        #     cam = heatmap*0.4 + np.float32(img)
-        #     cam = cam / np.max(cam)
-        #     return cam
-
-        # wsi = show_cam_on_image(wsi, mask)
-        # wsi = ((wsi-wsi.min())/(wsi.max()-wsi.min()) * 255.0).astype(np.uint8)
-        
-        # img = Image.fromarray(wsi)
-        # img = img.convert('RGB')
-        
-
-        # output_path = self.save_path / str(Y.item())
-        # output_path.mkdir(parents=True, exist_ok=True)
-        # img.save(f'{output_path}/{wsi_name}.jpg')
 
+    def test_step(self, batch, batch_idx):
 
-        #----> Get Topk Tiles and Topk Patients
-        summed = torch.mean(grayscale_cam, dim=2)
-        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
-        topk_data = data[topk_indices].detach()
+        input, label, (wsi_name, batch_names, patient) = batch
+        label = label.float()
         
-        # target_ft = 
-        # grayscale_cam_ft = self.cam_ft(input_tensor=data, )
-        # for i in range(data.shape[0]):
-            
-            # vis_img = data[i, :, :, :].cpu().numpy()
-            # vis_img = np.transpose(vis_img, (1,2,0))
-            # print(vis_img.shape)
-            # cam_img = grayscale_cam.squeeze(0)
-        # cam_img = self.reshape_transform(grayscale_cam)
-
-        # print(cam_img.shape)
-            
-            # visualization = show_cam_on_image(vis_img, cam_img, use_rgb=True)
-            # visualization = ((visualization/visualization.max())*255.0).astype(np.uint8)
-            # print(visualization)
-        # cv2.imwrite(f'{test_path}/{Y}/{name}/gradcam.jpg', cam_img)
+        logits, Y_prob, Y_hat = self.step(input) 
 
         #---->acc log
-        Y = torch.argmax(label)
+        Y = int(label)
+        # Y = torch.argmax(label)
+
+        # print(Y_hat)
         self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (Y_hat.item() == Y)
+        self.data[Y]["correct"] += (int(Y_hat) == Y)
+        # self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'topk_data': topk_data} #
-        # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': wsi_name, 'patient': patient}
 
     def test_epoch_end(self, output_results):
         logits = torch.cat([x['logits'] for x in output_results], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in output_results])
         max_probs = torch.stack([x['Y_hat'] for x in output_results])
-        # target = torch.stack([x['label'] for x in output_results], dim = 0)
-        target = torch.stack([x['label'] for x in output_results])
-        # target = torch.argmax(target, dim=1)
-        patients = [x['name'] for x in output_results]
-        topk_tiles = [x['topk_data'] for x in output_results]
-        #---->
-        auc = self.AUROC(probs, target)
-        fpr, tpr, thresholds = self.ROC(probs, target)
-        fpr = fpr.cpu().numpy()
-        tpr = tpr.cpu().numpy()
+        target = torch.stack([x['label'] for x in output_results]).int()
+        slide_names = [x['name'] for x in output_results]
+        patients = [x['patient'] for x in output_results]
+        
+        self.log_dict(self.test_metrics(max_probs.squeeze(), target.squeeze()),
+                          on_epoch = True, logger = True, sync_dist=True)
+        self.log('test_loss', cross_entropy_torch(logits.squeeze(), target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
-        plt.figure(1)
-        plt.plot(fpr, tpr)
-        plt.xlabel('False positive rate')
-        plt.ylabel('True positive rate')
-        plt.title('ROC curve')
-        plt.savefig(f'{self.save_path}/roc.jpg')
-        # self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+        if len(target.unique()) != 1:
+            self.log('test_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            # self.log('val_patient_auc', self.AUROC(patient_score, patient_target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:    
+            self.log('test_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
-        metrics = self.test_metrics(logits , target)
 
 
-        # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
-        metrics['test_auc'] = auc
+        #----> log confusion matrix
+        self.log_confusion_matrix(max_probs, target, stage='test')
 
-        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+        #----> log per patient metrics
+        complete_patient_dict = {}
+        patient_list = []            
+        patient_score = []      
+        patient_target = []
+        patient_class_score = 0
+
+        for p, s, pr, t in zip(patients, slide_names, probs, target):
+            if p not in complete_patient_dict.keys():
+                complete_patient_dict[p] = [(s, pr)]
+                patient_target.append(t)
+            else:
+                complete_patient_dict[p].append((s, pr))
 
-        #---->get highest scoring patients for each class
-        # test_path = Path(self.save_path) / 'most_predictive' 
-        
-        # Path.mkdir(output_path, exist_ok=True)
-        topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
-        for n in range(self.n_classes):
-            print('class: ', n)
-            
-            topk_patients = [patients[i[n]] for i in topk_indices]
-            topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
-            for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
-                print(p, x[n])
-                patient = p
-                # outpath = test_path / str(n) / patient 
-                outpath = Path(self.save_path) / str(n) / patient
-                outpath.mkdir(parents=True, exist_ok=True)
-                for i in range(len(t)):
-                    tile = t[i]
-                    tile = tile.cpu().numpy().transpose(1,2,0)
-                    tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
-                    tile = tile.astype(np.uint8)
-                    img = Image.fromarray(tile)
+       
+
+        for p in complete_patient_dict.keys():
+            # iterate with a local name; probs is still needed at slide level
+            score = [prob for (slide, prob) in complete_patient_dict[p]]
+            score = torch.stack(score)
+            if self.n_classes == 2:
+                # binary case: if any slide predicts the positive class
+                # (Diseased / Rejection), score the patient on those slides only
+                positive_positions = (score.argmax(dim=1) == 1).nonzero().squeeze()
+                if positive_positions.numel() != 0:
+                    score = score[positive_positions]
+            else:
+                score = torch.mean(score) #.cpu().detach().numpy()
+            patient_score.append(score)
+
+        # print(patient_score)
+
+        patient_score = torch.stack(patient_score)
+        # patient_target = torch.stack(patient_target)
+        patient_target = torch.cat(patient_target)
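+        # Aggregation rule above, illustrated (sketch, not executed): on a binary
+        # task only positively-predicted slides count, otherwise all slides average:
+        #   slide_probs = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])
+        #   positives = slide_probs[slide_probs.argmax(dim=1) == 1]
+        #   patient_prob = positives.mean(dim=0)  # tensor([0.2500, 0.7500])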
 
-            
-            
-        #----->visualize top predictive tiles
         
+        if len(patient_target.unique()) != 1:
+            self.log('test_patient_auc', self.AUROC(patient_score.squeeze(), patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:    
+            self.log('test_patient_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         
-
+        self.log_dict(self.test_patient_metrics(patient_score, patient_target),
+                          on_epoch = True, logger = True, sync_dist=True)
         
-                # img = img.squeeze(0).cpu().numpy()
-                # img = np.transpose(img, (1,2,0))
-                # # print(img)
-                # # print(grayscale_cam.shape)
-                # visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
-
+            
 
-        for keys, values in metrics.items():
-            print(f'{keys} = {values}')
-            metrics[keys] = values.cpu().numpy()
-        #---->acc log
+        # precision, recall, thresholds = self.PRC(probs, target)
 
+        
 
+        #---->acc log
         for c in range(self.n_classes):
             count = self.data[c]["count"]
             correct = self.data[c]["correct"]
@@ -573,37 +621,25 @@ class ModelInterface(pl.LightningModule):
                 acc = None
             else:
                 acc = float(correct) / count
-            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+            print('test class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
-
-        #---->plot auroc curve
-        # stats = stat_scores(probs, target, reduce='macro', num_classes=self.n_classes)
-        # fpr = {}
-        # tpr = {}
-        # for n in self.n_classes: 
-
-        # fpr, tpr, thresh = roc_curve(target.cpu().numpy(), probs.cpu().numpy())
-        #[tp, fp, tn, fn, tp+fn]
-
-
-        self.log_confusion_matrix(max_probs, target, stage='test')
-        #---->
-        result = pd.DataFrame([metrics])
-        result.to_csv(Path(self.save_path) / f'test_result.csv', mode='a', header=not Path(self.save_path).exists())
-
-        # with open(f'{self.save_path}/test_metrics.txt', 'a') as f:
-
-        #     f.write([metrics])
+        
+        #----> if data shuffling is enabled, advance the seed so each run sees a new permutation
+        if self.shuffle:
+            self.count = self.count + 1
+            random.seed(self.count * 50)
 
     def configure_optimizers(self):
         # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
         optimizer = create_optimizer(self.optimizer, self.model)
-        # optimizer = PESG(self.model, a=self.loss.a, b=self.loss.b, loss_fn=self.loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
+        # optimizer = PESG(self.model, loss_fn=self.aucm_loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
         # optimizer = PDSCA(self.model, loss_fn=self.loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
-        return optimizer     
+        scheduler = {'scheduler': ReduceLROnPlateau(optimizer, mode='min', factor=0.5), 'monitor': 'val_loss', 'frequency': 5}
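+        # Lightning scheduler-dict convention: 'monitor' names the logged metric
+        # that ReduceLROnPlateau reacts to, and 'frequency' is how many validation
+        # intervals pass between scheduler steps.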
+        
+        return [optimizer], [scheduler]     
 
-    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
-        optimizer.zero_grad(set_to_none=True)
+    # def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
+    #     optimizer.zero_grad(set_to_none=True)
 
     def reshape_transform(self, tensor):
         # print(tensor.shape)
@@ -618,10 +654,12 @@ class ModelInterface(pl.LightningModule):
 
     def load_model(self):
         name = self.hparams.model.name
-        backbone = self.hparams.model.backbone
         # Convert the model file name to its class name, e.g. `trans_unet.py` -> `TransUnet`.
         # Please always name your model file like `trans_unet.py`, with the
         # corresponding class or function named `TransUnet`.
+        if name == 'ViT':
+            self.model = ViT
+
         if '_' in name:
             camel_name = ''.join([i.capitalize() for i in name.split('_')])
         else:
@@ -686,7 +724,7 @@ class ModelInterface(pl.LightningModule):
         if stage == 'train':
             self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
         else:
-            fig_.savefig(f'{self.loggers[0].log_dir}/cm_test.png', dpi=400)
+            fig_.savefig(f'{self.loggers[0].log_dir}/cm_{stage}.png', dpi=400)
 
         fig_.clf()
 
diff --git a/code/monai_test.json b/code/monai_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..093623db790c7bbb554ac717a04138ab5ac53d24
--- /dev/null
+++ b/code/monai_test.json
@@ -0,0 +1 @@
+{"training": [{"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}, {"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}, {"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}], "validation": [{"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}]}
\ No newline at end of file
diff --git a/code/test_visualize.py b/code/test_visualize.py
index 5e4a5e76491d1725a78c90216f077c4eac32bd17..79f814cd75aa73b498cfb140d36ad78a8d0b1568 100644
--- a/code/test_visualize.py
+++ b/code/test_visualize.py
@@ -38,7 +38,7 @@ def make_parse():
     parser.add_argument('--config', default='../DeepGraft/TransMIL.yaml',type=str)
     parser.add_argument('--version', default=0,type=int)
     parser.add_argument('--epoch', default='0',type=str)
-    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--gpus', default = 0, type=int)
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
     parser.add_argument('--bag_size', default = 10000, type=int)
@@ -54,6 +54,7 @@ class custom_test_module(ModelInterface):
     def test_step(self, batch, batch_idx):
 
         torch.set_grad_enabled(True)
+
         input_data, label, (wsi_name, batch_names, patient) = batch
         patient = patient[0]
         wsi_name = wsi_name[0]
@@ -61,14 +62,18 @@ class custom_test_module(ModelInterface):
         # logits, Y_prob, Y_hat = self.step(data) 
         # print(data.shape)
         input_data = input_data.squeeze(0).float()
-        logits, attn = self(input_data)
-        attn = attn.detach()
-        logits = logits.detach()
+        # print(self.model_ft)
+        # print(self.model)
+        logits, _ = self(input_data)
+        # attn = attn.detach()
+        # logits = logits.detach()
 
         Y = torch.argmax(label)
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim=1)
 
+        
+
         # print('Y_hat:', Y_hat)
         # print('Y_prob:', Y_prob)
 
@@ -87,9 +92,16 @@ class custom_test_module(ModelInterface):
             target_layers = [self.model.attention_weights]
             self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
 
-        data_ft = self.model_ft(input_data).unsqueeze(0).float()
+        if self.model_ft:
+            data_ft = self.model_ft(input_data).unsqueeze(0).float()
+        else:
+            data_ft = input_data.unsqueeze(0).float()
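+        # When no feature extractor is attached (backbone == 'features'), the
+        # bag already holds embeddings and GradCAM runs on them directly.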
         instance_count = input_data.size(0)
+        # data_ft.requires_grad=True
+        
         target = [ClassifierOutputTarget(Y)]
+        # print(target)
+        
         grayscale_cam = self.cam(input_tensor=data_ft, targets=target, eigen_smooth=True)
         grayscale_cam = torch.Tensor(grayscale_cam)[:instance_count, :] #.to(self.device)
 
@@ -100,6 +112,7 @@ class custom_test_module(ModelInterface):
         summed = torch.mean(grayscale_cam, dim=2)
         topk_tiles, topk_indices = torch.topk(summed.squeeze(0), k, dim=0)
         topk_data = input_data[topk_indices].detach()
+        # print(topk_tiles)
         
         #----------------------------------------------------
         # Log Correct/Count
@@ -115,7 +128,7 @@ class custom_test_module(ModelInterface):
         # print(input_data.shape)
         # print(len(batch_names))
         # if visualize:
-        #     self.save_attention_map(wsi_name, input_data, batch_names, grayscale_cam, target=Y)
+        # self.save_attention_map(wsi_name, batch_names, grayscale_cam, target=Y)
         # print('test_step_patient: ', patient)
 
         return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'patient': patient, 'topk_data': topk_data} #
@@ -128,7 +141,6 @@ class custom_test_module(ModelInterface):
 
         pp = pprint.PrettyPrinter(indent=4)
 
-
         logits = torch.cat([x['logits'] for x in output_results], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in output_results])
         max_probs = torch.stack([x['Y_hat'] for x in output_results])
@@ -158,7 +170,6 @@ class custom_test_module(ModelInterface):
         '''
         Patient
         -> slides:
-            
             -> SlideName:
                 ->probs = [0.5, 0.5] 
                 ->topk = [10,3,224,224]
@@ -180,11 +191,11 @@ class custom_test_module(ModelInterface):
                 score.append(complete_patient_dict[p]['slides'][s]['probs'])
             score = torch.mean(torch.stack(score), dim=0) #.cpu().detach().numpy()
             complete_patient_dict[p]['score'] = score
-            print(p, score)
+            # print(p, score)
             patient_list.append(p)    
             patient_score.append(score)    
 
-        print(patient_list)
+        # print(patient_list)
         #topk patients: 
 
 
@@ -212,37 +223,34 @@ class custom_test_module(ModelInterface):
             output_dict[class_name] = {}
             # class_name = str(n)
             print('class: ', class_name)
-            print(score)
+            # print(score)
             _, topk_indices = torch.topk(score, k_patient, dim=0) # change to 3
-            print(topk_indices)
+            # print(topk_indices)
 
             topk_patients = [patient_list[i] for i in topk_indices]
 
             patient_top_slides = {} 
             for p in topk_patients:
-                print(p)
+                # print(p)
                 output_dict[class_name][p] = {}
                 output_dict[class_name][p]['Patient_Score'] = complete_patient_dict[p]['score'].cpu().detach().numpy().tolist()
 
                 slides = list(complete_patient_dict[p]['slides'].keys())
                 slide_scores = [complete_patient_dict[p]['slides'][s]['probs'] for s in slides]
                 slide_scores = torch.stack(slide_scores)
-                print(slide_scores)
+                # print(slide_scores)
                 _, topk_slide_indices = torch.topk(slide_scores, k_slide, dim=0)
                 # topk_slide_indices = topk_slide_indices.squeeze(0)
-                print(topk_slide_indices[0])
+                # print(topk_slide_indices[0])
                 topk_patient_slides = [slides[i] for i in topk_slide_indices[0]]
                 patient_top_slides[p] = topk_patient_slides
 
                 output_dict[class_name][p]['Top_Slides'] = [{slides[i]: {'Slide_Score': slide_scores[i].cpu().detach().numpy().tolist()}} for i in topk_slide_indices[0]]
-            
-
-            
 
             for p in topk_patients: 
 
                 score = complete_patient_dict[p]['score']
-                print(p, score)
+                # print(p, score)
                 print('Topk Slides:')
                 for slide in patient_top_slides[p]:
                     print(slide)
@@ -250,21 +258,18 @@ class custom_test_module(ModelInterface):
                     outpath.mkdir(parents=True, exist_ok=True)
                 
                     topk_tiles = complete_patient_dict[p]['slides'][slide]['topk']
-                    for i in range(topk_tiles.shape[0]):
-                        tile = topk_tiles[i]
-                        tile = tile.cpu().numpy().transpose(1,2,0)
-                        tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
-                        tile = tile.astype(np.uint8)
-                        img = Image.fromarray(tile)
+                    # for i in range(topk_tiles.shape[0]):
+                    #     tile = topk_tiles[i]
+                    #     tile = tile.cpu().numpy().transpose(1,2,0)
+                    #     tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
+                    #     tile = tile.astype(np.uint8)
+                    #     img = Image.fromarray(tile)
                     
-                    img.save(f'{outpath}/{i}.jpg')
+                    #     img.save(f'{outpath}/{i}.jpg')
         output_dict['Test_Metrics'] = np_metrics
         pp.pprint(output_dict)
         json.dump(output_dict, open(f'{self.save_path}/test_metrics.json', 'w'))
 
-
-        
-
         for keys, values in metrics.items():
             print(f'{keys} = {values}')
             metrics[keys] = values.cpu().numpy()
@@ -286,20 +291,35 @@ class custom_test_module(ModelInterface):
         result = pd.DataFrame([metrics])
         result.to_csv(Path(self.save_path) / f'test_result.csv', mode='a', header=not Path(self.save_path).exists())
 
-    def save_attention_map(self, wsi_name, data, batch_names, grayscale_cam, target):
+    def save_attention_map(self, wsi_name, batch_names, grayscale_cam, target):
 
-        def get_coords(batch_names): #ToDO: Change function for precise coords
-            coords = []
+        # def get_coords(batch_names): #ToDO: Change function for precise coords
+        #     coords = []
             
-            for tile_name in batch_names: 
-                pos = re.findall(r'\((.*?)\)', tile_name[0])
-                x, y = pos[-1].split('_')
-                coords.append((int(x),int(y)))
-            return coords
-        
-        coords = get_coords(batch_names)
+        #     for tile_name in batch_names: 
+        #         pos = re.findall(r'\((.*?)\)', tile_name[0])
+        #         x, y = pos[-1].split('_')
+        #         coords.append((int(x),int(y)))
+        #     return coords
+
+        home = Path.cwd().parts[1]
+        jpg_dir = f'/{home}/ylan/data/DeepGraft/224_128um_annotated/Aachen_Biopsy_Slides/BLOCKS'
+
+        coords = batch_names.squeeze()
+        data = []
+        for co in coords:
+
+            tile_path =  Path(jpg_dir) / wsi_name / f'{wsi_name}_({co[0]}_{co[1]}).jpg'
+            img = np.asarray(Image.open(tile_path)).astype(np.uint8)
+            img = torch.from_numpy(img)
+            # print(img.shape)
+            data.append(img)
         # coords_set = set(coords)
-
+        # data = data.unsqueeze(0)
+        # print(data.shape)
+        data = torch.stack(data)
+        # print(data.max())
+        # print(data.min())
         # print(coords)
         # temp_data = data.cpu()
         # print(data.shape)
@@ -307,7 +327,7 @@ class custom_test_module(ModelInterface):
         # wsi = (wsi-wsi.min())/(wsi.max()-wsi.min())
         # wsi = wsi
         # print(coords)
-        print('wsi.shape: ', wsi.shape)
+        # print('wsi.shape: ', wsi.shape)
         #--> Get interpolated mask from GradCam
         W, H = wsi.shape[0], wsi.shape[1]
         
@@ -318,8 +338,8 @@ class custom_test_module(ModelInterface):
         input_h = 224
         
         mask = torch.ones(( int(W/input_h), int(H/input_h))).to(self.device)
-        print('mask.shape: ', mask.shape)
-        print('attention_map.shape: ', attention_map.shape)
+        # print('mask.shape: ', mask.shape)
+        # print('attention_map.shape: ', attention_map.shape)
         for i, (x,y) in enumerate(coords):
             mask[y][x] = attention_map[i]
         mask = mask.unsqueeze(0).unsqueeze(0)
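+        # The mask holds one attention value per 224 px tile; unsqueezing to
+        # NCHW shape prepares it for upsampling to pixel space (e.g. via
+        # F.interpolate) before overlaying on the WSI.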
@@ -343,12 +363,12 @@ class custom_test_module(ModelInterface):
         
         size = (20000, 20000)
 
-        img = Image.fromarray(wsi_cam)
-        img = img.convert('RGB')
-        img.thumbnail(size, Image.ANTIALIAS)
-        output_path = self.save_path / str(target.item())
-        output_path.mkdir(parents=True, exist_ok=True)
-        img.save(f'{output_path}/{wsi_name}_gradcam.jpg')
+        # img = Image.fromarray(wsi_cam)
+        # img = img.convert('RGB')
+        # img.thumbnail(size, Image.ANTIALIAS)
+        # output_path = self.save_path / str(target.item())
+        # output_path.mkdir(parents=True, exist_ok=True)
+        # img.save(f'{output_path}/{wsi_name}_gradcam.jpg')
 
         wsi = ((wsi-wsi.min())/(wsi.max()-wsi.min()) * 255.0).astype(np.uint8)
         img = Image.fromarray(wsi)
@@ -365,11 +385,11 @@ class custom_test_module(ModelInterface):
 
     def assemble(self, tiles, coords): # with coordinates (x-y)
         
-        def getPosition(img_name):
-            pos = re.findall(r'\((.*?)\)', img_name) #get strings in brackets (0-0)
-            a = int(pos[0].split('-')[0])
-            b = int(pos[0].split('-')[1])
-            return a, b
+        # def getPosition(img_name):
+        #     pos = re.findall(r'\((.*?)\)', img_name) #get strings in brackets (0-0)
+        #     a = int(pos[0].split('-')[0])
+        #     b = int(pos[0].split('-')[1])
+        #     return a, b
 
         position_dict = {}
         assembled = []
@@ -384,19 +404,23 @@ class custom_test_module(ModelInterface):
 
         for i, (x,y) in enumerate(coords):
             if x not in position_dict.keys():
-                position_dict[x] = [(y, i)]
-            else: position_dict[x].append((y, i))
+                position_dict[x.item()] = [(y.item(), i)]
+            else: position_dict[x.item()].append((y.item(), i))
         # x_positions = sorted(position_dict.keys())
 
         test_img_compl = torch.ones([(y_max+1)*224, (x_max+1)*224, 3]).to(self.device)
+
         for i in range(x_max+1):
             if i in position_dict.keys():
                 for j in position_dict[i]:
                     sample_idx = j[1]
-                    if tiles[sample_idx, :, :, :].shape != [3,224,224]:
-                        img = tiles[sample_idx, :, :, :].permute(1,2,0)
-                    else: 
-                        img = tiles[sample_idx, :, :, :]
+                    # if tiles[sample_idx, :, :, :].shape != [3,224,224]:
+                    #     img = tiles[sample_idx, :, :, :].permute(2,0,1)
+                    # else: 
+                    img = tiles[sample_idx, :, :, :]
+                    # print(img.shape)
+                    # print(img.max())
+                    # print(img.min())
                     y_coord = int(j[0])
                     x_coord = int(i)
                     test_img_compl[y_coord*224:(y_coord+1)*224, x_coord*224:(x_coord+1)*224, :] = img
@@ -453,6 +477,9 @@ def main(cfg):
     # cfg.Data.label_file = '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
     # cfg.Data.patient_slide = '/homeStor1/ylan/DeepGraft/training_tables/cohort_stain_dict.json'
     # cfg.Data.data_dir = '/homeStor1/ylan/data/DeepGraft/224_128um_v2/'
+    use_features = (cfg.Model.backbone == 'features')
     DataInterface_dict = {
                 'data_root': cfg.Data.data_dir,
                 'label_path': cfg.Data.label_file,
@@ -461,6 +488,7 @@ def main(cfg):
                 'n_classes': cfg.Model.n_classes,
                 'backbone': cfg.Model.backbone,
                 'bag_size': cfg.Data.bag_size,
+                'use_features': use_features,
                 }
 
     dm = MILDataModule(**DataInterface_dict)
@@ -489,7 +517,8 @@ def main(cfg):
         # callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        gpus=cfg.General.gpus,
+        accelerator='gpu',
+        devices=cfg.General.gpus,
         # gpus = [0,2],
         # strategy='ddp',
         amp_backend='native',
@@ -508,7 +537,7 @@ def main(cfg):
     # log_path = Path('lightning_logs/2/checkpoints')
     model_paths = list(log_path.glob('*.ckpt'))
 
-
+    # print(model_paths)
     if cfg.epoch == 'last':
         model_paths = [str(model_path) for model_path in model_paths if f'last' in str(model_path)]
     else:
@@ -583,5 +612,121 @@ if __name__ == '__main__':
     
 
     #---->main
-    main(cfg)
+    # main(cfg)
+    from models import TransMIL
+    from datasets.zarr_feature_dataloader_simple import ZarrFeatureBagLoader
+    from datasets.feature_dataloader import FeatureBagLoader
+    from torch.utils.data import random_split, DataLoader
+    import time
+    from tqdm import tqdm
+    import torchmetrics
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    print(device)
+    scaler = torch.cuda.amp.GradScaler()
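+    # GradScaler is only needed for mixed-precision training; the evaluation
+    # loop below relies on torch.cuda.amp.autocast alone.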
+    
+    log_path = Path(cfg.log_path) / 'checkpoints'
+    model_paths = list(log_path.glob('*.ckpt'))
+
+    # print(model_paths)
+    if cfg.epoch == 'last':
+        model_paths = [str(model_path) for model_path in model_paths if 'last' in str(model_path)]
+    else:
+        model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+
+    # checkpoint = torch.load(f'{cfg.log_path}/checkpoints/epoch=04-val_loss=0.4243-val_auc=0.8243-val_patient_auc=0.8282244801521301.ckpt')
+    # checkpoint = torch.load(f'{cfg.log_path}/checkpoints/epoch=73-val_loss=0.8574-val_auc=0.9682-val_patient_auc=0.9724310636520386.ckpt')
+    checkpoint = torch.load(model_paths[0])
+
+    hyper_parameters = checkpoint['hyper_parameters']
+    n_classes = hyper_parameters['model']['n_classes']
+
+    # model = TransMIL()
+    model = TransMIL(n_classes).to(device)
+    model_weights = checkpoint['state_dict']
+
+    for key in list(model_weights):
+        model_weights[key.replace('model.', '')] = model_weights.pop(key)
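+    # Lightning prefixes submodule weights with 'model.' in its checkpoints, so
+    # the keys are renamed before loading into the bare TransMIL. A defensive
+    # one-liner (sketch, not part of the original) would strip the prefix only
+    # where present:
+    #   model_weights = {k.removeprefix('model.'): v for k, v in model_weights.items()}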
+    
+    model.load_state_dict(model_weights)
+
+    count = 0
+    # for m in model.modules():
+    #     if isinstance(m, torch.nn.BatchNorm2d):
+    #         # # m.track_running_stats = False
+    #         # count += 1 #skip the first BatchNorm layer in my ResNet50 based encoder
+    #         # if count >= 2:
+    #             # m.eval()
+    #         print(m)
+    #         m.track_running_stats = False
+    #         m.running_mean = None
+    #         m.running_var = None
+    
+    for param in model.parameters():
+        param.requires_grad = False
+    model.eval()
+
+    home = Path.cwd().parts[1]
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128uM_annotated'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    dataset = FeatureBagLoader(data_root, label_path=label_path, mode='test', cache=False, n_classes=n_classes)
+
+    dl = DataLoader(dataset, batch_size=1, num_workers=8)
+
+    
+
+    AUROC = torchmetrics.AUROC(num_classes = n_classes)
+
+    start = time.time()
+    test_logits = []
+    test_probs = []
+    test_labels = []
+    data = [{"count": 0, "correct": 0} for i in range(n_classes)]
+
+    for item in tqdm(dl): 
+
+        bag, label, (name, batch_coords, patient) = item
+        # label = label.float()
+        Y = int(label)
+
+        bag = bag.float().to(device)
+        # print(bag.shape)
+        bag = bag.unsqueeze(0)
+        with torch.cuda.amp.autocast():
+            logits = model(bag)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+
+        # print(Y_prob)
+
+        test_logits.append(logits)
+        test_probs.append(Y_prob)
+
+        test_labels.append(label)
+        data[Y]['count'] += 1
+        data[Y]['correct'] += (int(Y_hat) == Y)
+    probs = torch.cat(test_probs).detach().cpu()
+    targets = torch.stack(test_labels).squeeze().detach().cpu()
+    print(probs.shape)
+    print(targets.shape)
+
+    
+    for c in range(n_classes):
+        count = data[c]['count']
+        correct = data[c]['correct']
+        if count == 0:
+            acc = None
+        else: 
+            acc = float(correct) / count
+        print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+
+
+
+    auroc = AUROC(probs, targets)
+    print(auroc)
+    end = time.time()
+    print('Bag Time: ', end-start)
+
+
+
  
\ No newline at end of file
diff --git a/code/train.py b/code/train.py
index e01bc52c17b0b8b0698124bb4ae0e61d087a585d..53ab165add41af536e18863583798d424df59ec2 100644
--- a/code/train.py
+++ b/code/train.py
@@ -5,7 +5,8 @@ import glob
 
 from sklearn.model_selection import KFold
 
-from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
+from datasets.data_interface import MILDataModule, CrossVal_MILDataModule
+# from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
 from models.model_interface import ModelInterface
 from models.model_interface_dtfd import ModelInterface_DTFD
 import models.vision_transformer as vits
@@ -63,7 +64,9 @@ def make_parse():
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
     parser.add_argument('--version', default=2,type=int)
-    parser.add_argument('--gpus', nargs='+', default = [2], type=int)
+    parser.add_argument('--epoch', default='0',type=str)
+
+    parser.add_argument('--gpus', nargs='+', default = [0], type=int)
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
     parser.add_argument('--bag_size', default = 1024, type=int)
@@ -78,7 +81,7 @@ def make_parse():
 #---->main
 def main(cfg):
 
-    torch.set_num_threads(16)
+    torch.set_num_threads(8)
 
     #---->Initialize seed
     pl.seed_everything(cfg.General.seed)
@@ -111,6 +114,8 @@ def main(cfg):
                 'n_classes': cfg.Model.n_classes,
                 'bag_size': cfg.Data.bag_size,
                 'use_features': use_features,
+                'mixup': cfg.Data.mixup,
+                'aug': cfg.Data.aug,
                 }
 
     if cfg.Data.cross_val:
@@ -142,7 +147,7 @@ def main(cfg):
             logger=cfg.load_loggers,
             callbacks=cfg.callbacks,
             max_epochs= cfg.General.epochs,
-            min_epochs = 100,
+            min_epochs = 500,
             accelerator='gpu',
             # plugins=plugins,
             devices=cfg.General.gpus,
@@ -156,7 +161,7 @@ def main(cfg):
             # limit_train_batches=1,
             
             # deterministic=True,
-            check_val_every_n_epoch=5,
+            check_val_every_n_epoch=1,
         )
     else:
         trainer = Trainer(
@@ -167,7 +172,7 @@ def main(cfg):
             min_epochs = 100,
 
             # gpus=cfg.General.gpus,
-            accelerator='gpu'
+            accelerator='gpu',
             devices=cfg.General.gpus,
             amp_backend='native',
             # amp_level=cfg.General.amp_level,  
@@ -178,7 +183,7 @@ def main(cfg):
             # limit_train_batches=1,
             
             # deterministic=True,
-            check_val_every_n_epoch=5,
+            check_val_every_n_epoch=1,
         )
     # print(cfg.log_path)
     # print(trainer.loggers[0].log_dir)
@@ -215,18 +220,29 @@ def main(cfg):
         else:
             trainer.fit(model = model, datamodule = dm)
     else:
-        log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' 
+        log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}'/'checkpoints' 
 
+        print(log_path)
         test_path = Path(log_path) / 'test'
-        for n in range(cfg.Model.n_classes):
-            n_output_path = test_path / str(n)
-            n_output_path.mkdir(parents=True, exist_ok=True)
+        # for n in range(cfg.Model.n_classes):
+        #     n_output_path = test_path / str(n)
+        #     n_output_path.mkdir(parents=True, exist_ok=True)
         # print(cfg.log_path)
         model_paths = list(log_path.glob('*.ckpt'))
-        model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
+        # print(model_paths)
+        # print(cfg.epoch)
+        # model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
+        if cfg.epoch == 'last':
+            model_paths = [str(model_path) for model_path in model_paths if 'last' in str(model_path)]
+        else:
+            if int(cfg.epoch) < 10:
+                cfg.epoch = f'0{cfg.epoch}' # zero-pad to match the '{epoch:02d}' checkpoint names
+            model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
         # model_paths = [f'{log_path}/epoch=279-val_loss=0.4009.ckpt']
+
         for path in model_paths:
-            print(path)
+            # print(path)
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
             trainer.test(model=new_model, datamodule=dm)
 
@@ -257,6 +273,7 @@ def check_home(cfg):
 if __name__ == '__main__':
 
     args = make_parse()
+
     cfg = read_yaml(args.config)
 
     #---->update
@@ -283,7 +300,10 @@ if __name__ == '__main__':
     cfg.task = task
     # task = Path(cfg.config).name[:-5].split('_')[2:][0]
     cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name 
-    
+
+
+
+    cfg.epoch = args.epoch
     
 
     # ---->main
diff --git a/code/utils/__pycache__/utils.cpython-39.pyc b/code/utils/__pycache__/utils.cpython-39.pyc
index 26de104113838d7e9f56ecbe311a3c8813d08d77..df4436cad807f2a148eddd27875e344eb4dcbe14 100644
Binary files a/code/utils/__pycache__/utils.cpython-39.pyc and b/code/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/code/utils/utils.py b/code/utils/utils.py
index 5596208e41716da17ac42e18741e09b99ff93667..9a756d58876ee847677f414a99318cf426fcf75d 100755
--- a/code/utils/utils.py
+++ b/code/utils/utils.py
@@ -14,9 +14,11 @@ from pytorch_lightning import LightningModule
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.callbacks import LearningRateMonitor
 from typing import Any, Dict, List, Optional, Type
 import shutil
 
+
 #---->read yaml
 import yaml
 from addict import Dict
@@ -103,7 +105,7 @@ def load_callbacks(cfg, save_path):
         # save_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.resume_version}' / last.ckpt
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
                                          dirpath = str(output_path),
-                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc: .4f}',
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
                                          save_top_k = 2,
@@ -111,7 +113,15 @@ def load_callbacks(cfg, save_path):
                                          save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
                                          dirpath = str(output_path),
-                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}',
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc:.4f}',
+                                         verbose = True,
+                                         save_last = True,
+                                         save_top_k = 2,
+                                         mode = 'max',
+                                         save_weights_only = True))
+        Mycallbacks.append(ModelCheckpoint(monitor = 'val_patient_auc',
+                                         dirpath = str(output_path),
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
                                          save_top_k = 2,
@@ -121,6 +131,9 @@ def load_callbacks(cfg, save_path):
     swa = StochasticWeightAveraging(swa_lrs=1e-2)
     Mycallbacks.append(swa)
 
+    lr_monitor = LearningRateMonitor(logging_interval='step')
+    Mycallbacks.append(lr_monitor)
+
     return Mycallbacks
 
 #---->val loss
@@ -128,7 +141,8 @@ import torch
 import torch.nn.functional as F
 def cross_entropy_torch(x, y):
     x_softmax = [F.softmax(x[i], dim=0) for i in range(len(x))]
-    x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
+    x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(len(y))])
+    # x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
     loss = - torch.sum(x_log) / y.shape[0]
     return loss
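+# For tensor inputs the loop above is equivalent to the vectorized built-in
+#   loss = F.cross_entropy(x, y)
+# i.e. the mean over i of -log(softmax(x_i)[y_i]).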
 
diff --git a/monai_test.json b/monai_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..093623db790c7bbb554ac717a04138ab5ac53d24
--- /dev/null
+++ b/monai_test.json
@@ -0,0 +1 @@
+{"training": [{"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}, {"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}, {"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}], "validation": [{"image": "Aachen_KiBiDatabase_KiBiAcRCIQ360_01_018_PAS.svs", "label": 0}]}
\ No newline at end of file
diff --git a/paper_structure.md b/paper_structure.md
new file mode 100644
index 0000000000000000000000000000000000000000..c6747192ab2aba827b9a3f787bbf5f1b169f9990
--- /dev/null
+++ b/paper_structure.md
@@ -0,0 +1,35 @@
+# Paper Outline
+
+## Abstract
+
+## Introduction
+
+## Methods
+
+    - Fig 1: Model/Workflow
+
+## Dataset
+
+    - Fig: cohorts, data selection
+    - Fig: Preprocessing
+
+## Results
+
+    - Fig: Metrics on Testset for each task:
+
+| Model        | Accuracy | Precision | Recall | AUROC |
+| ------------ | -------- | --------- | ------ | ----- |
+| Resnet18     |          |           |        |       |
+| ViT          |          |           |        |       |
+| CLAM         |          |           |        |       |
+| AttentionMIL |          |           |        |       |
+| TransMIL     |          |           |        |       |
+|              |          |           |        |       |
+
+    - Fig: AUROC Curves (Best Model, Rest in Appendix)
+
+    - Fig: Attention Maps (Best Model, Rest in Appendix)
+
+## Discussion
+
+## Appendix
diff --git a/project_plan.md b/project_plan.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0e8379c0f5872e1e70d9095c347ab37d9fc19da
--- /dev/null
+++ b/project_plan.md
@@ -0,0 +1,102 @@
+#   Benchmarking weakly supervised deep learning models for transplant pathology classification
+
+With this project, we aim to establish a benchmark for weakly supervised deep learning models in transplant pathology classification, with a particular focus on multiple instance learning (MIL) approaches. 
+
+
+## Cohorts:
+
+#### Original Lancet Set:
+
+    * Training:
+        * AMS: 1130 Biopsies (3390 WSI)
+        * Utrecht: 717 Biopsies (2151 WSI)
+    * Testing:
+        * Aachen: 101 Biopsies (303 WSI)
+
+
+#### Extended:
+
+* Training:
+  * AMS + Utrecht + Leuven
+* Testing:
+  * Aachen_extended:
+
+## Models:
+
+    For our benchmark, we chose the following models: 
+
+    - AttentionMIL
+    - Resnet18/50
+    - ViT
+    - CLAM
+    - TransMIL
+    - Monai MIL (optional)
+
+    Resnet18 and Resnet50 are basic CNNs that can be applied to a wide variety of tasks. Although domain- or task-specific architectures mostly outperform them, they remain a good baseline for comparison. 
+
+    The vision transformer (ViT) was the first transformer-based model adapted to computer vision tasks. Benchmarking ViT provides insight into how generic transformer-based models perform on multiple instance learning. 
+
+    AttentionMIL was the first simple yet relatively successful deep MIL model and serves as a baseline for benchmarking MIL methods. 
+
+    CLAM is a recent model proposed by the Mahmood lab, designed explicitly for histopathological whole slide images; it serves as a baseline for benchmarking MIL methods in histopathology. 
+
+    TransMIL is another model, proposed by Shao et al., that achieved SOTA on histopathological WSI classification tasks using MIL. It was benchmarked on TCGA and compared against CLAM and AttMIL. It utilizes the self-attention module from transformer models.
+
+    Monai MIL (not an official name) is a MIL architecture proposed by Myronenko et al. (Nvidia). It likewise applies the self-attention mechanism. It is included because it shows promising results and ships with MONAI. 
+
+## Tasks:
+
+    The Original tasks mimic those published in the original DeepGraft Lancet paper. 
+    Before moving on to more challenging tasks (see Future below), we want to establish that our models outperform the simpler approach of the previous paper and that adopting MIL in this setting is indeed worthwhile. 
+
+    All available classes: 
+        * Normal
+        * TCMR
+        * ABMR
+        * Mixed
+        * Viral
+        * Other
+
+#### Original:
+
+    The explicit classes are simplified by grouping them as follows (sketched in code after the task list): 
+    Diseased = all classes other than Normal 
+    Rejection = TCMR, ABMR, Mixed 
+
+    - (1) Normal vs Diseased (all other classes)
+    - (2) Rejection vs (Viral + Others)
+    - (3) Normal vs Rejection vs (Viral + Others)
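+
+    A minimal sketch of task (3) as a Python label map (illustrative; the
+    actual grouping is defined in the training-table JSON files):
+
+        # Normal vs Rejection vs (Viral + Others)
+        label_map = {
+            'Normal': 0,
+            'TCMR': 1, 'ABMR': 1, 'Mixed': 1,   # Rejection
+            'Viral': 2, 'Other': 2,             # Rest
+        }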
+
+#### Future:
+
+    After validating the Original tasks, the next step is to challenge the models with more complicated ones. 
+    These experiments may vary depending on the results of the previous experiments.
+
+    - (4) Normal vs TCMR vs Mixed vs ABMR vs Viral vs Others
+    - (5) TCMR vs Mixed vs ABMR
+
+## Plan:
+
+    1. Train models for current tasks on AMS+Utrecht -> Validate on Aachen
+
+    2. Visualization, AUC Curves
+
+    3. Train best model on extended training set (AMS+Utrecht+Leuven) (Tasks 1,2,3) -> Validate on Aachen_extended
+        - Investigate if a larger training cohort increases performance
+    4. Train best model on extended dataset on future tasks (Task 4, 5)
+
+
+    Notes: 
+        * Resnet18, ViT and CLAM are all trained on HIA (Training Framework from Kather / Narmin)
+    
+
+## Status: 
+
+        - Resnet18: Trained on all tasks via HIA  
+        - ViT: Trained on all tasks via HIA 
+        - CLAM: Trained on (1) via HIA 
+        - TransMIL: Trained, but overfitting
+            - Check that the problem is not on the model side by evaluating on RCC data. 
+            - (mixing in 10 slides from Aachen increases AUC from 0.7 to 0.89)
+        - AttentionMIL: WIP
+        - Monai MIL: WIP