diff --git a/.gitignore b/.gitignore
index c9e4fdabe0660f6872b93a3d8a59aa1750a671e0..104287f888d356d4b44e435e944beea70047797c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 logs/*
 lightning_logs/*
-test/*
\ No newline at end of file
+test/*
+DeepGraft_Project_Plan_12.7.22.pdf
+monai_test.json
diff --git a/DeepGraft/Inception_norm_rest.yaml b/DeepGraft/Inception_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..48cd4eb8a366bb53529eab6666ac82e23f611950
--- /dev/null
+++ b/DeepGraft/Inception_norm_rest.yaml
@@ -0,0 +1,55 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16
+    multi_gpu_mode: ddp
+    gpus: [0, 1]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: train #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    mixup: False
+    aug: True
+    cache: False
+    data_dir: '/home/ylan/data/DeepGraft/224_256uM_annotated/'
+    label_file: '/home/ylan/data/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_val_1.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 500 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: inception
+    n_classes: 2
+    backbone: inception
+    in_features: 2048
+    out_features: 1024
+
+
+Optimizer:
+    opt: adam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/Resnet18_img_norm_rest.yaml b/DeepGraft/Resnet18_img_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d2183e414ca19ccc60a154d72898c64134598464
--- /dev/null
+++ b/DeepGraft/Resnet18_img_norm_rest.yaml
@@ -0,0 +1,55 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16
+    multi_gpu_mode: ddp
+    gpus: [0, 1]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: train #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    mixup: True
+    aug: True
+    cache: False
+    data_dir: '/home/ylan/data/DeepGraft/224_256uM_annotated/'
+    label_file: '/home/ylan/data/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_val_1.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1000 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: resnet18
+    n_classes: 2
+    backbone: resnet18
+    in_features: 2048
+    out_features: 1024
+
+
+Optimizer:
+    opt: Adam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft/Resnet50.yaml b/DeepGraft/Resnet50.yaml
deleted file mode 100644
index e6b780b18947b315ee77bb0c6b8b0dbd433fa249..0000000000000000000000000000000000000000
--- a/DeepGraft/Resnet50.yaml
+++ /dev/null
@@ -1,49 +0,0 @@
-General:
-    comment: 
-    seed: 2021
-    fp16: True
-    amp_level: O2
-    precision: 16 
-    multi_gpu_mode: dp
-    gpus: [0]
-    epochs: &epoch 200 
-    grad_acc: 2
-    frozen_bn: False
-    patience: 200
-    server: test #train #test
-    log_path: logs/
-
-Data:
-    dataset_name: custom
-    data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
-    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json'
-    fold: 0
-    nfold: 4
-
-    train_dataloader:
-        batch_size: 1 
-        num_workers: 8
-
-    test_dataloader:
-        batch_size: 1
-        num_workers: 8
-
-        
-
-Model:
-    name: resnet50
-    n_classes: 2
-
-
-Optimizer:
-    opt: lookahead_radam
-    lr: 0.0002
-    opt_eps: null 
-    opt_betas: null
-    momentum: null 
-    weight_decay: 0.00001
-
-Loss:
-    base_loss: CrossEntropyLoss
-
diff --git a/DeepGraft/TransMIL_feat_norm_rej_rest.yaml b/DeepGraft/TransMIL_feat_norm_rej_rest.yaml
index 3dcc81775c0741d3f9cf02b8b501349bb70323c8..3a9886f7945782d4ef9c665b5d72dc3dfd116503 100644
--- a/DeepGraft/TransMIL_feat_norm_rej_rest.yaml
+++ b/DeepGraft/TransMIL_feat_norm_rej_rest.yaml
@@ -3,27 +3,30 @@ General:
     seed: 2021
     fp16: True
     amp_level: O2
-    precision: 16 
-    multi_gpu_mode: dp
-    gpus: [0]
+    precision: 32
+    multi_gpu_mode: ddp
+    gpus: [0, 1]
     epochs: &epoch 1000 
     grad_acc: 2
     frozen_bn: False
-    patience: 100
-    server: test #train #test
+    patience: 300
+    server: train #train #test
     log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
-    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rej_rest.json'
+    mixup: True
+    aug: True
+    cache: False
+    data_dir: '/home/ylan/data/DeepGraft/224_256uM_annotated/'
+    label_file: '/home/ylan/data/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rej_rest_val_1.json'
     fold: 1
     nfold: 3
     cross_val: False
 
     train_dataloader:
-        batch_size: 1 
+        batch_size: 100
         num_workers: 4
 
     test_dataloader:
@@ -34,13 +37,13 @@ Model:
     name: TransMIL
     n_classes: 3
     backbone: features
-    in_features: 512
+    in_features: 2048
     out_features: 512
 
 
 Optimizer:
     opt: lookahead_radam
-    lr: 0.0002
+    lr: 0.002
     opt_eps: null 
     opt_betas: null
     momentum: null 
diff --git a/DeepGraft/TransMIL_feat_norm_rest.yaml b/DeepGraft/TransMIL_feat_norm_rest.yaml
index ea452a34b4d1a4ba83aa669486def30b9e5bb12e..1651aa436e4feb72d282d90169a6c34c8acdeb57 100644
--- a/DeepGraft/TransMIL_feat_norm_rest.yaml
+++ b/DeepGraft/TransMIL_feat_norm_rest.yaml
@@ -4,9 +4,9 @@ General:
     fp16: True
     amp_level: O2
     precision: 16
-    multi_gpu_mode: dp
+    multi_gpu_mode: ddp
     gpus: [0, 1]
-    epochs: &epoch 500 
+    epochs: &epoch 1000 
     grad_acc: 2
     frozen_bn: False
     patience: 50
@@ -16,16 +16,17 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    mixup: False
+    mixup: True
     aug: True
-    data_dir: '/home/ylan/data/DeepGraft/224_128uM_annotated/'
-    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    cache: False
+    data_dir: '/home/ylan/data/DeepGraft/224_256uM_annotated/'
+    label_file: '/home/ylan/data/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_val_1.json'
     fold: 1
     nfold: 3
     cross_val: False
 
     train_dataloader:
-        batch_size: 1 
+        batch_size: 100
         num_workers: 4
 
     test_dataloader:
@@ -37,12 +38,12 @@ Model:
     n_classes: 2
     backbone: features
     in_features: 2048
-    out_features: 1024
+    out_features: 512
 
 
 Optimizer:
-    opt: Adam
-    lr: 0.0001
+    opt: lookahead_radam
+    lr: 0.002
     opt_eps: null 
     opt_betas: null
     momentum: null 
diff --git a/DeepGraft/TransMIL_feat_rej_rest.yaml b/DeepGraft/TransMIL_feat_rej_rest.yaml
index ca9c0e47e15a2588a2fee91529653f2a7c07c735..4a054f25e9ccebb0f935464889c2b3d0bbdf8be5 100644
--- a/DeepGraft/TransMIL_feat_rej_rest.yaml
+++ b/DeepGraft/TransMIL_feat_rej_rest.yaml
@@ -3,44 +3,47 @@ General:
     seed: 2021
     fp16: True
     amp_level: O2
-    precision: 16 
-    multi_gpu_mode: dp
-    gpus: [0]
-    epochs: &epoch 500 
+    precision: 16
+    multi_gpu_mode: ddp
+    gpus: [0, 1]
+    epochs: &epoch 1000 
     grad_acc: 2
     frozen_bn: False
     patience: 50
-    server: test #train #test
+    server: train #train #test
     log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
-    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_rej_rest.json'
+    mixup: True
+    aug: True
+    cache: False
+    data_dir: '/home/ylan/data/DeepGraft/224_256uM_annotated/'
+    label_file: '/home/ylan/data/DeepGraft/training_tables/dg_split_PAS_HE_Jones_rej_rest_val_1.json'
     fold: 1
     nfold: 3
     cross_val: False
 
     train_dataloader:
-        batch_size: 1 
-        num_workers: 8
+        batch_size: 100
+        num_workers: 4
 
     test_dataloader:
         batch_size: 1
-        num_workers: 8
+        num_workers: 4
 
 Model:
     name: TransMIL
     n_classes: 2
     backbone: features
-    in_features: 1024
+    in_features: 2048
     out_features: 512
 
 
 Optimizer:
     opt: lookahead_radam
-    lr: 0.0002
+    lr: 0.002
     opt_eps: null 
     opt_betas: null
     momentum: null 
diff --git a/DeepGraft/TransMIL_retccl_norm_rest.yaml b/DeepGraft/TransMIL_retccl_norm_rest.yaml
index fa9988de796bf5e7f2d8ce418ccac95b6d28863e..9b04677a7402218b2d008a34385c1664dbc1ac0b 100644
--- a/DeepGraft/TransMIL_retccl_norm_rest.yaml
+++ b/DeepGraft/TransMIL_retccl_norm_rest.yaml
@@ -16,14 +16,14 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
-    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    data_dir: '/home/ylan/data/DeepGraft/224_256uM_annotated/'
+    label_file: '/home/ylan/data/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_img_val_1.json'
     fold: 1
     nfold: 3
     cross_val: False
 
     train_dataloader:
-        batch_size: 1 
+        batch_size: 5 
         num_workers: 4
 
     test_dataloader:
@@ -34,8 +34,8 @@ Model:
     name: TransMIL
     n_classes: 2
     backbone: retccl
-    in_features: 512
-    out_features: 1024
+    in_features: 2048
+    out_features: 512
 
 
 Optimizer:
diff --git a/DeepGraft/TransformerMIL_feat_norm_rest.yaml b/DeepGraft/TransformerMIL_feat_norm_rest.yaml
index 7c90fbff0e56441623ff2ae68cf791776cf7d589..2f86ef11fff79b9138465fa3b762656b8d3a386a 100644
--- a/DeepGraft/TransformerMIL_feat_norm_rest.yaml
+++ b/DeepGraft/TransformerMIL_feat_norm_rest.yaml
@@ -10,14 +10,14 @@ General:
     grad_acc: 2
     frozen_bn: False
     patience: 100
-    server: test #train #test
+    server: train #train #test
     log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/224_128um_v2/'
-    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
+    data_dir: '/home/ylan/data/DeepGraft/224_256uM_annotated/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_val_1.json'
     fold: 1
     nfold: 3
     cross_val: False
@@ -34,8 +34,8 @@ Model:
     name: TransformerMIL
     n_classes: 2
     backbone: features
-    in_features: 512
-    out_features: 1024
+    in_features: 2048
+    out_features: 512
 
 
 Optimizer:
diff --git a/DeepGraft/Vit_norm_rest.yaml b/DeepGraft/Vit_norm_rest.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae4a0e5c8342f935335eb3d765f96dac199135fe
--- /dev/null
+++ b/DeepGraft/Vit_norm_rest.yaml
@@ -0,0 +1,55 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16
+    multi_gpu_mode: ddp
+    gpus: [0, 1]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: train #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    mixup: True
+    aug: True
+    cache: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128uM_annotated/'
+    label_file: '/home/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_val_1.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 500 
+        num_workers: 4
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 4
+
+Model:
+    name: vit
+    n_classes: 2
+    backbone: vit
+    in_features: 2048
+    out_features: 1024
+
+
+Optimizer:
+    opt: Adam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/DeepGraft_Project_Plan.pdf b/DeepGraft_Project_Plan.pdf
deleted file mode 100644
index 1b5a0c95fca1a588080f74941d9141b5b45bb604..0000000000000000000000000000000000000000
Binary files a/DeepGraft_Project_Plan.pdf and /dev/null differ
diff --git a/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc
index 5d4cc8bb3818896e5944e5eb2f6b551388b67e48..7273dc294c45323d033079fc36e6b2181013059d 100644
Binary files a/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc and b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ
diff --git a/code/MyLoss/loss_factory.py b/code/MyLoss/loss_factory.py
index f17f69804e81667ee796e7e509803ea5735829e4..912533349dc5db1172066e0bcd5219d5cbd1a62a 100755
--- a/code/MyLoss/loss_factory.py
+++ b/code/MyLoss/loss_factory.py
@@ -27,10 +27,10 @@ def create_loss(args, n_classes, w1=1.0, w2=0.5):
     ### MulticlassJaccardLoss(classes=np.arange(11)
     # mode = args.base_loss #BINARY_MODE \MULTICLASS_MODE \MULTILABEL_MODE 
     loss = None
-    print(conf_loss)
+    # print(conf_loss)
     if hasattr(nn, conf_loss): 
         loss = getattr(nn, conf_loss)()
-        # loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
+        # loss = getattr(nn, conf_loss)(label_smoothing=0.1) 
     #binary loss
     elif conf_loss == "focal":
         loss = FocalLoss_Ori(n_classes)
diff --git a/code/__pycache__/test_visualize.cpython-39.pyc b/code/__pycache__/test_visualize.cpython-39.pyc
index c3a94d6f97af78fa95e1263a76b8e60928a923d5..e06f73a4c5069d526f1bbf4449d7223a191e5e73 100644
Binary files a/code/__pycache__/test_visualize.cpython-39.pyc and b/code/__pycache__/test_visualize.cpython-39.pyc differ
diff --git a/code/cufile.log b/code/cufile.log
new file mode 100644
index 0000000000000000000000000000000000000000..d9aee1a5053fa16c26a24e59ed6a06f56af1f035
--- /dev/null
+++ b/code/cufile.log
@@ -0,0 +1,6 @@
+ 21-12-2022 16:48:14:373 [pid=1690629 tid=1690629] NOTICE  cufio-drv:625 running in compatible mode
+ 22-12-2022 10:31:41:400 [pid=1904890 tid=1904890] NOTICE  cufio-drv:625 running in compatible mode
+ 22-12-2022 10:52:13:216 [pid=1909914 tid=1909914] NOTICE  cufio-drv:625 running in compatible mode
+ 22-12-2022 11:02:15:996 [pid=1912278 tid=1912278] NOTICE  cufio-drv:625 running in compatible mode
+ 22-12-2022 11:15:17:212 [pid=1915495 tid=1915495] NOTICE  cufio-drv:625 running in compatible mode
+ 02-01-2023 00:11:43:868 [pid=931838 tid=931838] NOTICE  cufio-drv:625 running in compatible mode
diff --git a/code/datasets/__init__.py b/code/datasets/__init__.py
index 2989858e6e652de44eaa34b0eb5ba798f1ffefdf..4f1906435ca47d707bc4841563557f8a419af3fb 100644
--- a/code/datasets/__init__.py
+++ b/code/datasets/__init__.py
@@ -1,4 +1,4 @@
-
-from .custom_jpg_dataloader import JPGMILDataloader
+# from .custom_jpg_dataloader import JPGMILDataloader
+from .jpg_dataloader import JPGMILDataloader
 from .data_interface import MILDataModule
 from .fast_tensor_dl import FastTensorDataLoader
diff --git a/code/datasets/__pycache__/__init__.cpython-39.pyc b/code/datasets/__pycache__/__init__.cpython-39.pyc
index d67531a6e6443d08e896d436a6b2ba379fe43528..72bc845bf4c597bb8c39e8bd4a53b0601a654ebc 100644
Binary files a/code/datasets/__pycache__/__init__.cpython-39.pyc and b/code/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/classic_jpg_dataloader.cpython-39.pyc b/code/datasets/__pycache__/classic_jpg_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56a80385988c3216122973acc207f2a9ffeb76e6
Binary files /dev/null and b/code/datasets/__pycache__/classic_jpg_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/data_interface.cpython-39.pyc b/code/datasets/__pycache__/data_interface.cpython-39.pyc
index e1151f291a6f1bd56821ff01a34f3019adc21b35..798fdf96f92b70a082697e2b35225649b461b9e4 100644
Binary files a/code/datasets/__pycache__/data_interface.cpython-39.pyc and b/code/datasets/__pycache__/data_interface.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/feature_dataloader.cpython-39.pyc b/code/datasets/__pycache__/feature_dataloader.cpython-39.pyc
index 10319c23951eae75c10e1a3f2836050359059030..60ef98c89f08368cb1252c836eccbf99d987cfbf 100644
Binary files a/code/datasets/__pycache__/feature_dataloader.cpython-39.pyc and b/code/datasets/__pycache__/feature_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/jpg_dataloader.cpython-39.pyc b/code/datasets/__pycache__/jpg_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32985c0a0392e944a17373474d4eb59dea3f7a71
Binary files /dev/null and b/code/datasets/__pycache__/jpg_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/myTransforms.cpython-39.pyc b/code/datasets/__pycache__/myTransforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76f721cf88d4e4a13955f2aacc4747eed8fbd2fc
Binary files /dev/null and b/code/datasets/__pycache__/myTransforms.cpython-39.pyc differ
diff --git a/code/datasets/__pycache__/simple_jpg_dataloader.cpython-39.pyc b/code/datasets/__pycache__/simple_jpg_dataloader.cpython-39.pyc
index 782f8d99cad131fbd9c433fc963873773108d79b..6cc398072d0abb4066cf4db1cd155d247b46928f 100644
Binary files a/code/datasets/__pycache__/simple_jpg_dataloader.cpython-39.pyc and b/code/datasets/__pycache__/simple_jpg_dataloader.cpython-39.pyc differ
diff --git a/code/datasets/classic_jpg_dataloader.py b/code/datasets/classic_jpg_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c98e52cc1f8b4c78ea1f51b8e724028bac13f7d
--- /dev/null
+++ b/code/datasets/classic_jpg_dataloader.py
@@ -0,0 +1,336 @@
+# import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+import torch.utils.data as data_utils
+import torchvision.transforms as transforms
+import pandas as pd
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+from PIL import Image
+import cv2
+import json
+from imgaug import augmenters as iaa
+from torchsampler import ImbalancedDatasetSampler
+from .utils import myTransforms
+
+
+class JPGBagLoader(data_utils.Dataset):
+    def __init__(self, file_path, label_path, mode, n_classes, data_cache_size=100, max_bag_size=1000, cache=False, mixup=False, aug=False, model='inception'):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.min_bag_size = 50
+        self.empty_slides = []
+        self.corrupt_slides = []
+        self.cache = False # caching is not supported by this loader; the cache argument is ignored
+        self.labels = []
+        if model == 'inception':
+            size = 299
+        elif model == 'vit':
+            size = 384
+        else: size = 224
+
+        
+        # read labels and slide paths from the label json file
+        with open(self.label_path, 'r') as f:
+            json_dict = json.load(f)
+            temp_slide_label_dict = json_dict[self.mode]
+            # print(len(temp_slide_label_dict))
+            for (x,y) in temp_slide_label_dict:
+                x = x.replace('FEATURES_RETCCL_2048', 'BLOCKS')
+                # print(x)
+                x_name = Path(x).stem
+                x_path_list = [Path(self.file_path)/x]
+                for x_path in x_path_list:
+                    if x_path.exists():
+                        # print(len(list(x_path.glob('*'))))
+
+                        self.slideLabelDict[x_name] = y
+                        self.labels += [int(y)]*len(list(x_path.glob('*')))
+                        # self.labels.append(int(y))
+                        for patch in x_path.iterdir():
+                            self.files.append((patch, x_name, y))
+
+        # with open(self.label_path, 'r') as f:
+        #     temp_slide_label_dict = json.load(f)[mode]
+        #     print(len(temp_slide_label_dict))
+        #     for (x, y) in temp_slide_label_dict:
+        #         x = Path(x).stem 
+        #         # x_complete_path = Path(self.file_path)/Path(x)
+        #         for cohort in Path(self.file_path).iterdir():
+        #             x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
+        #             if x_complete_path.is_dir():
+        #                 if len(list(x_complete_path.iterdir())) > self.min_bag_size:
+        #                 # print(x_complete_path)
+        #                     self.slideLabelDict[x] = y
+        #                     self.files.append(x_complete_path)
+        #                 else: self.empty_slides.append(x_complete_path)
+        
+        home = Path.cwd().parts[1]
+        self.slide_patient_dict_path = Path(self.label_path).parent / 'slide_patient_dict_an.json'
+        # self.slide_patient_dict_path = f'/{home}/ylan/data/DeepGraft/training_tables/slide_patient_dict_an.json'
+        with open(self.slide_patient_dict_path, 'r') as f:
+            self.slide_patient_dict = json.load(f)
+
+
+        self.color_transforms = myTransforms.Compose([
+            myTransforms.ColorJitter(
+                brightness = (0.65, 1.35), 
+                contrast = (0.5, 1.5),
+                # saturation=(0, 2), 
+                # hue=0.3,
+                ),
+            # myTransforms.RandomChoice([myTransforms.ColorJitter(saturation=(0, 2), hue=0.3),
+            #                             myTransforms.HEDJitter(theta=0.05)]),
+            myTransforms.HEDJitter(theta=0.005),
+            
+        ])
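+        # NOTE: the reassignment below replaces the ColorJitter/HEDJitter pipeline above
+        # with a grayscale-only transform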
+        self.color_transforms = myTransforms.Compose([
+            myTransforms.Grayscale(num_output_channels=3)
+        ])
+        self.train_transforms = myTransforms.Compose([
+            myTransforms.RandomChoice([myTransforms.RandomHorizontalFlip(p=0.5),
+                                        myTransforms.RandomVerticalFlip(p=0.5),
+                                        myTransforms.AutoRandomRotation()]),
+        
+            myTransforms.RandomGaussBlur(radius=[0.5, 1.5]),
+            myTransforms.RandomAffineCV2(alpha=0.1),
+            myTransforms.RandomElastic(alpha=2, sigma=0.06),
+        ])
+
+        self.resize_transforms = transforms.Resize((size, size), transforms.InterpolationMode.BICUBIC)
+
+        # sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        # sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        # sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        # sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        # sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        # self.resize_transforms = iaa.Sequential([
+        #     iaa.Resize({'height': size, 'width': size}),
+        #     # iaa.Resize({'height': 299, 'width': 299}),
+        # ], name='resizeAug')
+        # # self.resize_transforms = transforms.Resize(size=(299,299))
+
+        # self.train_transforms = iaa.Sequential([
+        #     iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
+        #     sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+        #     iaa.Fliplr(0.5, name="MyFlipLR"),
+        #     iaa.Flipud(0.5, name="MyFlipUD"),
+        #     sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+        #     # iaa.OneOf([
+        #     #     sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+        #     #     sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+        #     #     sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+        #     # ], name="MyOneOf")
+
+        # ], name="MyAug")
+        self.val_transforms = transforms.Compose([
+            # 
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225],
+            ),
+            # RangeNormalization(),
+        ])
+
+
+
+
+        
+
+    def get_data(self, query):
+        
+        patch_path, wsi_name, label = query
+
+        # img = np.asarray(Image.open(patch_path)).astype(np.uint8)
+        img = Image.open(patch_path)
+        # img = np.moveaxis(img, 2, 0)
+        # print(img.shape)
+        # img = torch.from_numpy(img)
+        tile_name = Path(patch_path).stem
+        # patient = tile_name.rsplit('_', 1)[0]
+        patient = self.slide_patient_dict[wsi_name]
+
+        # for tile_path in Path(file_path).iterdir():
+        #     img = np.asarray(Image.open(tile_path)).astype(np.uint8)
+        #     img = np.moveaxis(img, 2, 0)
+        #     # print(img.shape)
+        #     img = torch.from_numpy(img)
+        #     wsi_batch.append(img)
+        #     name_batch.append(tile_path.stem)
+
+        # wsi_batch = torch.stack(wsi_batch)
+        return img, label, (wsi_name, tile_name, patient)
+    
+    def get_labels(self, indices):
+        return [self.labels[i] for i in indices]
+
+
+    def to_fixed_size_bag(self, bag, bag_size: int = 512):
+
+        # randomly subsample up to bag_size instances; zero-pad below if the bag is smaller
+
+        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+        bag_samples = bag[bag_idxs]
+        # name_samples = [names[i] for i in bag_idxs]
+
+        # bag_sample_names = [bag_names[i] for i in bag_idxs]
+        # q, r  = divmod(bag_size, bag_samples.shape[0])
+        # if q > 0:
+        #     bag_samples = torch.cat([bag_samples]*q, 0)
+
+        # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+        # zero-pad if we don't have enough samples
+        zero_padded = torch.cat((bag_samples,
+                                torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+        return zero_padded, min(bag_size, len(bag))
+
+    def data_dropout(self, bag, drop_rate):
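+        # randomly keep a (1 - drop_rate) fraction of the bag's instances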
+        bag_size = bag.shape[0]
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_samples = bag[bag_idxs]
+        # name_samples = [batch_names[i] for i in bag_idxs]
+
+        return bag_samples
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+
+    
+        if self.cache:
+            label = self.labels[index]
+            wsi = self.features[index]
+            label = int(label)
+            wsi_name = self.wsi_names[index]
+            tile_name = self.name_batches[index]
+            patient = self.patients[index]
+            # feats = Variable(Tensor(feats))
+            return wsi, label, (wsi_name, tile_name, patient)
+        else:
+            t = self.files[index]
+            # label = self.labels[index]
+            if self.mode=='train':
+                # t = self.files[index]
+                # label = self.labels[index]
+                img, label, (wsi_name, tile_name, patient) = self.get_data(t)
+                save_img(img, f'{tile_name}_original')
+                img = self.resize_transforms(img)
+                img = self.color_transforms(img)
+                img = self.train_transforms(img)
+
+                # save_img(img, f'{tile_name}')
+
+                img = self.val_transforms(img.copy())
+
+                
+                # ft = ft.view(-1, 512)
+                
+            else:
+                img, label, (wsi_name, tile_name, patient) = self.get_data(t)
+                # label = Variable(Tensor(label))
+                # seq_img_d = self.train_transforms.to_deterministic()
+                # seq_img_resize = self.resize_transforms.to_deterministic()
+                # img = img.numpy().astype(np.uint8)
+                img = self.resize_transforms(img)
+                # img = np.moveaxis(img, 0, 2)
+                img = self.val_transforms(img)
+
+            return img, label, (wsi_name, tile_name, patient)
+
+def save_img(img, comment):
+    home = Path.cwd().parts[1]
+    outputPath = f'/{home}/ylan/data/DeepGraft/224_128uM_annotated/debug/augments_2'
+    img = img.convert('RGB')
+    img.save(f'{outputPath}/{comment}.jpg')
+
+if __name__ == '__main__':
+    
+    from pathlib import Path
+    import os
+    import time
+    from fast_tensor_dl import FastTensorDataLoader
+    from custom_resnet50 import resnet50_baseline
+    
+    
+
+    home = Path.cwd().parts[1]
+    # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_256uM_annotated'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    # label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
+    label_path = f'/{home}/ylan/data/DeepGraft/training_tables/dg_limit_5_split_PAS_HE_Jones_norm_rest_test.json'
+    # output_dir = f'/{data_root}/debug/augments'
+    # os.makedirs(output_dir, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = JPGBagLoader(data_root, label_path=label_path, mode='train', n_classes=n_classes, cache=False)
+    # dataset = JPGBagLoader(data_root, label_path=label_path, mode='train', n_classes=n_classes, cache=False)
+
+    # print(dataset.get_labels(0))
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_data, valid_data = random_split(dataset, [a, b])
+    # print(dataset.dataset)
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
+    dl = DataLoader(dataset, batch_size=5, num_workers=8, pin_memory=True)
+    # print(len(dl))
+    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    scaler = torch.cuda.amp.GradScaler()
+
+    model_ft = resnet50_baseline(pretrained=True)
+    for param in model_ft.parameters():
+        param.requires_grad = False
+    model_ft.to(device)
+    
+    c = 0
+    label_count = [0] *n_classes
+    # print(len(dl))
+    start = time.time()
+    for item in tqdm(dl): 
+
+        if c >= 1000:
+            break
+        bag, label, (name, batch_names, patient) = item
+        print(bag.shape)
+        # print(name)
+        # print(batch_names)
+        # print(patient)
+        # print(len(batch_names))
+
+        # bag = bag.squeeze(0).float().to(device)
+        # label = label.to(device)
+        # with torch.cuda.amp.autocast():
+        #     output = model_ft(bag)
+        c += 1
+    end = time.time()
+
+    print('Bag Time: ', end-start)
\ No newline at end of file
diff --git a/code/datasets/custom_jpg_dataloader.py b/code/datasets/custom_jpg_dataloader.py
index 95d5adb331955464376b8361c38a2e1f65d76be1..8e1ce3dc47fc05b0b1008817efc80f8813edeb87 100644
--- a/code/datasets/custom_jpg_dataloader.py
+++ b/code/datasets/custom_jpg_dataloader.py
@@ -19,6 +19,7 @@ from albumentations.pytorch import ToTensorV2
 from imgaug import augmenters as iaa
 import imgaug as ia
 from torchsampler import ImbalancedDatasetSampler
+from .utils import myTransforms
 
 
 
@@ -173,19 +174,48 @@ class JPGMILDataloader(data.Dataset):
 
     def __getitem__(self, index):
         # get data
+
+        color_transforms = myTransforms.Compose([
+            myTransforms.ColorJitter(
+                brightness = (0.65, 1.35), 
+                contrast = (0.5, 1.5),
+                # saturation=(0, 2), 
+                # hue=0.3,
+                ),
+            # myTransforms.RandomChoice([myTransforms.ColorJitter(saturation=(0, 2), hue=0.3),
+            #                             myTransforms.HEDJitter(theta=0.05)]),
+            myTransforms.HEDJitter(theta=0.005),
+            
+        ])
+        train_transforms = myTransforms.Compose([
+            myTransforms.RandomChoice([myTransforms.RandomHorizontalFlip(p=0.5),
+                                        myTransforms.RandomVerticalFlip(p=0.5),
+                                        myTransforms.AutoRandomRotation()]),
+        
+            myTransforms.RandomGaussBlur(radius=[0.5, 1.5]),
+            myTransforms.RandomAffineCV2(alpha=0.1),
+            myTransforms.RandomElastic(alpha=2, sigma=0.06),
+        ])
+
+
         (batch, batch_names), label, name, patient = self.get_data(index)
         out_batch = []
-        seq_img_d = self.train_transforms.to_deterministic()
+        # seq_img_d = self.train_transforms.to_deterministic()
         
         if self.mode == 'train':
             # print(img)
             # print(.shape)
             for img in batch: # expects numpy 
                 img = img.numpy().astype(np.uint8)
+
+
+                img = color_transforms(img)
+                img = train_transforms(img)
                 # img = self.albu_transforms(image=img)
                 # print(img)
                 # print(img.shape)
-                img = seq_img_d.augment_image(img)
+                # img = seq_img_d.augment_image(img)
+
                 img = self.val_transforms(img.copy())
                 # print(img)
                 out_batch.append(img)
diff --git a/code/datasets/data_interface.py b/code/datasets/data_interface.py
index 7049b225637bf349978a69ac33a9f2743f97c8bf..065e8446f2efed51faa90e9cf897180d8fa6790a 100644
--- a/code/datasets/data_interface.py
+++ b/code/datasets/data_interface.py
@@ -12,8 +12,8 @@ from torchvision.datasets import MNIST
 from torchvision import transforms
 # from .camel_dataloader import FeatureBagLoader
 from .custom_dataloader import HDF5MILDataloader
-# from .custom_jpg_dataloader import JPGMILDataloader
-from .simple_jpg_dataloader import JPGBagLoader
+from .jpg_dataloader import JPGMILDataloader
+from .classic_jpg_dataloader import JPGBagLoader
 from .zarr_feature_dataloader_simple import ZarrFeatureBagLoader
 from .feature_dataloader import FeatureBagLoader
 from pathlib import Path
@@ -124,7 +124,7 @@ import torch
 
 class MILDataModule(pl.LightningDataModule):
 
-    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, use_features=False, mixup=False, aug=False, *args, **kwargs):
+    def __init__(self, data_root: str, label_path: str, model_name: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, use_features=False, train_classic=False, mixup=False, aug=False, *args, **kwargs):
         super().__init__()
         self.data_root = data_root
         self.label_path = label_path
@@ -140,34 +140,43 @@ class MILDataModule(pl.LightningDataModule):
         self.seed = 1
         self.mixup = mixup
         self.aug = aug
+        self.train_classic = train_classic
+        self.max_bag_size = 1000
+        self.model_name = model_name
 
 
         self.class_weight = []
         self.cache = cache
         self.fe_transform = None
-        if not use_features: 
+        # print('use_features: ', use_features)
+        if self.train_classic: 
             self.base_dataloader = JPGBagLoader
+        elif not use_features: 
+            self.base_dataloader = JPGMILDataloader
         else: 
             self.base_dataloader = FeatureBagLoader
-            self.cache = True
+            # self.cache = True
 
     def setup(self, stage: Optional[str] = None) -> None:
         home = Path.cwd().parts[1]
 
         if stage in (None, 'fit'):
-            dataset = self.base_dataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, cache=self.cache, mixup=self.mixup, aug=self.aug)
+            self.train_data = self.base_dataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, cache=self.cache, mixup=self.mixup, aug=self.aug, model=self.model_name)
+            self.valid_data = self.base_dataloader(self.data_root, label_path=self.label_path, mode='val', n_classes=self.n_classes, cache=self.cache, model=self.model_name)
+
             # dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
-            print(len(dataset))
-            a = int(len(dataset)* 0.8)
-            b = int(len(dataset) - a)
-            self.train_data, self.valid_data = random_split(dataset, [a, b])
+            print('Train Data: ', len(self.train_data))
+            print('Val Data: ', len(self.valid_data))
+            # a = int(len(dataset)* 0.8)
+            # b = int(len(dataset) - a)
+            # self.train_data, self.valid_data = random_split(dataset, [a, b])
 
             # self.weights = self.get_weights(dataset)
 
 
 
         if stage in (None, 'test'):
-            self.test_data = self.base_dataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, cache=False)
+            self.test_data = self.base_dataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, cache=False, model=self.model_name, mixup=False, aug=False)
             print(len(self.test_data))
 
         return super().setup(stage=stage)
@@ -177,13 +186,17 @@ class MILDataModule(pl.LightningDataModule):
     def train_dataloader(self) -> DataLoader:
         # return DataLoader(self.train_data,  batch_size = self.batch_size, num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         # return DataLoader(self.train_data,  batch_size = self.batch_size, sampler = WeightedRandomSampler(self.weights, len(self.weights)), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
-        return DataLoader(self.train_data,  batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        if self.train_classic:
+            return DataLoader(self.train_data, batch_size = self.batch_size, num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        else:
+            return DataLoader(self.train_data,  batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+            # return DataLoader(self.train_data,  batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers, collate_fn=self.custom_collate) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         #sampler=ImbalancedDatasetSampler(self.train_data)
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
     
     def test_dataloader(self) -> DataLoader:
-        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.test_data, batch_size = 1, num_workers=self.num_workers)
 
     def get_weights(self, dataset):
 
@@ -201,6 +214,69 @@ class MILDataModule(pl.LightningDataModule):
 
         return torch.DoubleTensor(weights)
     
+    def custom_collate(self, batch):
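+        # draft collate_fn: currently only inspects the bags (prints their shapes) and returns
+        # the batch unchanged; the fixed-size-bag / mixup handling below is commented out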
+        # print(len(batch))
+        # print(len(batch))
+        for i in batch:
+            
+            bag, label, (wsi_name, patient) = i
+            print(bag.shape)
+        
+        # print(bag.shape)
+
+        # bag_size = bag.shape[0]
+        # bag_idxs = torch.randperm(bag_size)[:self.max_bag_size]
+        # # bag_idxs = torch.randperm(bag_size)[:int(self.max_bag_size*(1-self.drop_rate))]
+        # out_bag = bag[bag_idxs, :]
+        # if self.mixup:
+        #     out_bag = self.get_mixup_bag(out_bag)
+        #     # batch_coords = 
+        # if out_bag.shape[0] < self.max_bag_size:
+        #     out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+
+        # # shuffle again
+        # out_bag_idxs = torch.randperm(out_bag.shape[0])
+        # out_bag = out_bag[out_bag_idxs]
+        # batch_coords = batch_coords[bag_idxs]
+
+        
+        # return out_bag, label, (wsi_name, batch_coords, patient)
+        return batch
+
+        
+    def get_mixup_bag(self, bag):
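+        # bag-level mixup: draw random pairs of instances and take per-instance convex
+        # combinations; small bags are extended with mixed instances, otherwise original
+        # instances are randomly swapped for their mixed counterparts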
+
+        bag_size = bag.shape[0]
+
+        a = torch.rand([bag_size])
+        b = 0.6
+        rand_x = torch.randint(0, bag_size, [bag_size,])
+        rand_y = torch.randint(0, bag_size, [bag_size,])
+
+        bag_x = bag[rand_x, :]
+        bag_y = bag[rand_y, :]
+
+        temp_bag = (bag_x.t()*a).t() + (bag_y.t()*(1.0-a)).t()
+        # print('temp_bag: ', temp_bag.shape)
+
+        if bag_size < self.max_bag_size:
+            diff = self.max_bag_size - bag_size
+            bag_idxs = torch.randperm(bag_size)[:diff]
+            
+            # print('bag: ', bag.shape)
+            # print('bag_idxs: ', bag_idxs.shape)
+            mixup_bag = torch.cat((bag, temp_bag[bag_idxs, :]))
+            # print('mixup_bag: ', mixup_bag.shape)
+        else:
+            random_sample_list = torch.rand(bag_size)
+            mixup_bag = [bag[i] if random_sample_list[i] > b else temp_bag[i] for i in range(bag_size)] # TODO: make pytorch native
+            mixup_bag = torch.stack(mixup_bag)
+            # print('else')
+            # print(mixup_bag.shape)
+
+        return mixup_bag
+
+
 
 class DataModule(pl.LightningDataModule):
 
@@ -240,7 +316,7 @@ class DataModule(pl.LightningDataModule):
         return super().setup(stage=stage)
 
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.train_data,  self.batch_size, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        return DataLoader(self.train_data, self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), shuffle=False) #batch_transforms=self.transform, pseudo_batch_dim=True,
         #sampler=ImbalancedDatasetSampler(self.train_data),
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.valid_data, batch_size = self.batch_size)
@@ -325,3 +401,4 @@ class CrossVal_MILDataModule(BaseKFoldDataModule):
 
 
 
+# if __name__ == '__main__':
diff --git a/code/datasets/feature_dataloader.py b/code/datasets/feature_dataloader.py
index 3e9dbb567b7a791386e56a86127958df8e617e0b..21b80810dee28eba90427b97e6986a84f1a8c3a1 100644
--- a/code/datasets/feature_dataloader.py
+++ b/code/datasets/feature_dataloader.py
@@ -23,13 +23,14 @@ import h5py
 
 
 class FeatureBagLoader(data.Dataset):
-    def __init__(self, file_path, label_path, mode, n_classes, cache=False, mixup=False, aug=False, data_cache_size=5000, max_bag_size=1000):
+    def __init__(self, file_path, label_path, mode, model, n_classes, cache=False, mixup=False, aug=False, data_cache_size=5000, max_bag_size=1000):
         super().__init__()
 
         self.data_info = []
         self.data_cache = {}
         self.slideLabelDict = {}
         self.files = []
+        self.labels = []
         self.data_cache_size = data_cache_size
         self.mode = mode
         self.file_path = file_path
@@ -48,7 +49,7 @@ class FeatureBagLoader(data.Dataset):
         self.missing = []
 
         home = Path.cwd().parts[1]
-        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict_an.json'
+        self.slide_patient_dict_path = f'/{home}/ylan/data/DeepGraft/training_tables/slide_patient_dict_an.json'
         with open(self.slide_patient_dict_path, 'r') as f:
             self.slide_patient_dict = json.load(f)
 
@@ -59,25 +60,28 @@ class FeatureBagLoader(data.Dataset):
             # print(len(temp_slide_label_dict))
             for (x,y) in temp_slide_label_dict:
                 
+                x = x.replace('FEATURES_RETCCL_2048', 'FEATURES_RETCCL_2048_HED')
                 x_name = Path(x).stem
-                x_path_list = [Path(self.file_path)/x]
-                # x_name = x.stem
-                # x_path_list = [Path(self.file_path)/ x for (x,y) in temp_slide_label_dict]
-                if self.aug:
-                    for i in range(5):
-                        aug_path = Path(self.file_path)/f'{x}_aug{i}'
-                        x_path_list.append(aug_path)
-
-                for x_path in x_path_list: 
-                    
-                    if x_path.exists():
-                        self.slideLabelDict[x_name] = y
-                        self.files.append(x_path)
-                    elif Path(str(x_path) + '.zarr').exists():
-                        self.slideLabelDict[x] = y
-                        self.files.append(str(x_path)+'.zarr')
-                    else:
-                        self.missing.append(x)
+                if x_name in self.slide_patient_dict.keys():
+                    x_path_list = [Path(self.file_path)/x]
+                    # x_name = x.stem
+                    # x_path_list = [Path(self.file_path)/ x for (x,y) in temp_slide_label_dict]
+                    if self.aug:
+                        for i in range(10):
+                            aug_path = Path(self.file_path)/f'{x}_aug{i}'
+                            x_path_list.append(aug_path)
+
+                    for x_path in x_path_list: 
+                        
+                        if x_path.exists():
+                            self.slideLabelDict[x_name] = y
+                            self.labels.append(int(y))
+                            self.files.append(x_path)
+                        elif Path(str(x_path) + '.zarr').exists():
+                            self.slideLabelDict[x] = y
+                            self.files.append(str(x_path)+'.zarr')
+                        else:
+                            self.missing.append(x)
                 # print(x, y)
                 # x_complete_path = Path(self.file_path)/Path(x)
                 # for cohort in Path(self.file_path).iterdir():
@@ -129,22 +133,24 @@ class FeatureBagLoader(data.Dataset):
         
 
         self.feature_bags = []
-        self.labels = []
+        
         self.wsi_names = []
         self.coords = []
         self.patients = []
         if self.cache:
             for t in tqdm(self.files):
                 # zarr_t = str(t) + '.zarr'
-                batch, label, (wsi_name, batch_coords, patient) = self.get_data(t)
+                batch, (wsi_name, batch_coords, patient) = self.get_data(t)
 
                 # print(label)
-                self.labels.append(label)
+                # self.labels.append(label)
                 self.feature_bags.append(batch)
                 self.wsi_names.append(wsi_name)
                 self.coords.append(batch_coords)
                 self.patients.append(patient)
-        
+        # else: 
+        #     for t in tqdm(self.files):
+        #         self.labels = 
 
     def get_data(self, file_path):
         
@@ -154,7 +160,7 @@ class FeatureBagLoader(data.Dataset):
         if wsi_name.split('_')[-1][:3] == 'aug':
             wsi_name = '_'.join(wsi_name.split('_')[:-1])
         # if wsi_name in self.slideLabelDict:
-        label = self.slideLabelDict[wsi_name]
+        # label = self.slideLabelDict[wsi_name]
         patient = self.slide_patient_dict[wsi_name]
 
         if Path(file_path).suffix == '.zarr':
@@ -171,11 +177,11 @@ class FeatureBagLoader(data.Dataset):
         # np_bag = np.array(z['data'][:])
         # np_bag = np.array(zarr.open(file_path, 'r')).astype(np.uint8)
         # label = torch.as_tensor(label)
-        label = int(label)
+        # label = int(label)
         wsi_bag = torch.from_numpy(np_bag)
         batch_coords = torch.from_numpy(coords)
 
-        return wsi_bag, label, (wsi_name, batch_coords, patient)
+        return wsi_bag, (wsi_name, batch_coords, patient)
     
     def get_labels(self, indices):
         # for i in indices: 
@@ -224,10 +230,6 @@ class FeatureBagLoader(data.Dataset):
         bag_x = bag[rand_x, :]
         bag_y = bag[rand_y, :]
 
-        # print('bag_x: ', bag_x.shape)
-        # print('bag_y: ', bag_y.shape)
-        # print('a*bag_x: ', (a*bag_x).shape)
-        # print('(1.0-a)*bag_y: ', ((1.0-a)*bag_y).shape)
 
         temp_bag = (bag_x.t()*a).t() + (bag_y.t()*(1.0-a)).t()
         # print('temp_bag: ', temp_bag.shape)
@@ -275,7 +277,9 @@ class FeatureBagLoader(data.Dataset):
             # return wsi, label, (wsi_name, batch_coords, patient)
         else:
             t = self.files[index]
-            bag, label, (wsi_name, batch_coords, patient) = self.get_data(t)
+            label = self.labels[index]
+            bag, (wsi_name, batch_coords, patient) = self.get_data(t)
+            # print(bag.shape)
             # label = torch.as_tensor(label)
             # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
                 # self.labels.append(label)
@@ -283,34 +287,297 @@ class FeatureBagLoader(data.Dataset):
                 # self.wsi_names.append(wsi_name)
                 # self.name_batches.append(name_batch)
                 # self.patients.append(patient)
-        if self.mode == 'train':
-            bag_size = bag.shape[0]
+            if self.mode == 'train':
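+                # training: randomly subsample up to max_bag_size instances, optionally
+                # apply bag-level mixup, zero-pad to max_bag_size, then reshuffle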
+                bag_size = bag.shape[0]
+
+                bag_idxs = torch.randperm(bag_size)[:self.max_bag_size]
+                # bag_idxs = torch.randperm(bag_size)[:int(self.max_bag_size*(1-self.drop_rate))]
+                out_bag = bag[bag_idxs, :]
+                if self.mixup:
+                    out_bag = self.get_mixup_bag(out_bag)
+                    # batch_coords = 
+                if out_bag.shape[0] < self.max_bag_size:
+                    out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+
+                # shuffle again
+                out_bag_idxs = torch.randperm(out_bag.shape[0])
+                out_bag = out_bag[out_bag_idxs]
+
+
+                # batch_coords only useful for test
+                batch_coords = batch_coords[bag_idxs]
+                # out_bag = bag
+
+            # mixup? Linear combination of 2 vectors
+            # add noise
+
+
+            else: 
+                bag_size = bag.shape[0]
+                bag_idxs = torch.randperm(bag_size)[:self.max_bag_size]
+                out_bag = bag[bag_idxs, :]
+                if out_bag.shape[0] < self.max_bag_size:
+                    out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+                
+
+        return out_bag, label, (wsi_name, patient)
+        # return out_bag, label, (wsi_name, batch_coords, patient)
+
+class FeatureBagLoader_Mixed(data.Dataset):
+    def __init__(self, file_path, label_path, mode, n_classes, cache=False, mixup=False, aug=False, data_cache_size=5000, max_bag_size=1000):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.labels = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.drop_rate = 0.2
+        # self.min_bag_size = 120
+        self.empty_slides = []
+        self.corrupt_slides = []
+        self.cache = cache
+        self.mixup = mixup
+        self.aug = aug
+        
+        self.missing = []
+
+        home = Path.cwd().parts[1]
+        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict_an.json'
+        with open(self.slide_patient_dict_path, 'r') as f:
+            self.slide_patient_dict = json.load(f)
+
+        # read labels and slide paths from the label json file
+        with open(self.label_path, 'r') as f:
+            json_dict = json.load(f)
+            temp_slide_label_dict = json_dict[self.mode]
+            # print(len(temp_slide_label_dict))
+            for (x,y) in temp_slide_label_dict:
+                
+                x_name = Path(x).stem
+                x_path_list = [Path(self.file_path)/x]
+                # x_name = x.stem
+                # x_path_list = [Path(self.file_path)/ x for (x,y) in temp_slide_label_dict]
+                if self.aug:
+                    for i in range(5):
+                        aug_path = Path(self.file_path)/f'{x}_aug{i}'
+                        x_path_list.append(aug_path)
+
+                for x_path in x_path_list: 
+                    
+                    if x_path.exists():
+                        self.slideLabelDict[x_name] = y
+                        self.labels.append(int(y))
+                        self.files.append(x_path)
+                        for patch_path in x_path.iterdir():
+                            self.files.append((patch_path, x_name, y))
+
+
+        self.feature_bags = []
+        
+        self.wsi_names = []
+        self.coords = []
+        self.patients = []
+
+        if self.cache:
+            for t in tqdm(self.files):
+                # zarr_t = str(t) + '.zarr'
+                batch, (wsi_name, batch_coords, patient) = self.get_data(t)
+
+                # print(label)
+                # self.labels.append(label)
+                self.feature_bags.append(batch)
+                self.wsi_names.append(wsi_name)
+                self.coords.append(batch_coords)
+                self.patients.append(patient)
+        # else: 
+        #     for t in tqdm(self.files):
+        #         self.labels = 
+
+    # def create_bag(self):
+
+
+
+    def get_data(self, file_path):
+        
+        batch_names=[] #add function for name_batch read out
+
+        wsi_name = Path(file_path).stem
+        if wsi_name.split('_')[-1][:3] == 'aug':
+            wsi_name = '_'.join(wsi_name.split('_')[:-1])
+        # if wsi_name in self.slideLabelDict:
+        # label = self.slideLabelDict[wsi_name]
+        patient = self.slide_patient_dict[wsi_name]
+
+        if Path(file_path).suffix == '.zarr':
+            z = zarr.open(file_path, 'r')
+            np_bag = np.array(z['data'][:])
+            coords = np.array(z['coords'][:])
+        else:
+            with h5py.File(file_path, 'r') as hdf5_file:
+                np_bag = hdf5_file['features'][:]
+                coords = hdf5_file['coords'][:]
+
+        # np_bag = torch.load(file_path)
+        # z = zarr.open(file_path, 'r')
+        # np_bag = np.array(z['data'][:])
+        # np_bag = np.array(zarr.open(file_path, 'r')).astype(np.uint8)
+        # label = torch.as_tensor(label)
+        # label = int(label)
+        wsi_bag = torch.from_numpy(np_bag)
+        batch_coords = torch.from_numpy(coords)
+
+        return wsi_bag, (wsi_name, batch_coords, patient)
+    
+    def get_labels(self, indices):
+        # for i in indices: 
+        #     print(self.labels[i])
+        return [self.labels[i] for i in indices]
+
+
+    def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
+
+        # randomly subsample up to bag_size instances and their names (padding variant is commented out below)
+
+        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+        bag_samples = bag[bag_idxs]
+        name_samples = [names[i] for i in bag_idxs]
+        # bag_sample_names = [bag_names[i] for i in bag_idxs]
+        # q, r  = divmod(bag_size, bag_samples.shape[0])
+        # if q > 0:
+        #     bag_samples = torch.cat([bag_samples]*q, 0)
+
+        # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+        # zero-pad if we don't have enough samples
+        # zero_padded = torch.cat((bag_samples,
+        #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+        return bag_samples, name_samples, min(bag_size, len(bag))
+
+    def data_dropout(self, bag, batch_names, drop_rate):
+        # bag_size = self.max_bag_size
+        bag_size = bag.shape[0]
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_samples = bag[bag_idxs]
+        name_samples = [batch_names[i] for i in bag_idxs]
 
-            bag_idxs = torch.randperm(bag_size)[:self.max_bag_size]
-            # bag_idxs = torch.randperm(bag_size)[:int(self.max_bag_size*(1-self.drop_rate))]
-            out_bag = bag[bag_idxs, :]
-            if self.mixup:
-                out_bag = self.get_mixup_bag(out_bag)
-                # batch_coords = 
-            if out_bag.shape[0] < self.max_bag_size:
-                out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+        return bag_samples, name_samples
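+    # data_dropout keeps a random (1 - drop_rate) fraction of the bag, e.g. a
+    # 1000-instance bag with drop_rate=0.1 is reduced to a random 900-instance subset.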
 
-            # shuffle again
-            out_bag_idxs = torch.randperm(out_bag.shape[0])
-            out_bag = out_bag[out_bag_idxs]
+    def get_mixup_bag(self, bag):
 
+        bag_size = bag.shape[0]
 
-            # batch_coords only useful for test
-            batch_coords = batch_coords[bag_idxs]
+        a = torch.rand([bag_size])  # per-instance mixing coefficients
+        b = 0.6                     # threshold for choosing original vs. mixed instances
+        rand_x = torch.randint(0, bag_size, [bag_size,])
+        rand_y = torch.randint(0, bag_size, [bag_size,])
+
+        bag_x = bag[rand_x, :]
+        bag_y = bag[rand_y, :]
+
+
+        temp_bag = (bag_x.t()*a).t() + (bag_y.t()*(1.0-a)).t()
+        # print('temp_bag: ', temp_bag.shape)
+
+        if bag_size < self.max_bag_size:
+            diff = self.max_bag_size - bag_size
+            bag_idxs = torch.randperm(bag_size)[:diff]
+            
+            # print('bag: ', bag.shape)
+            # print('bag_idxs: ', bag_idxs.shape)
+            mixup_bag = torch.cat((bag, temp_bag[bag_idxs, :]))
+            # print('mixup_bag: ', mixup_bag.shape)
+        else:
+            # keep the original instance where the random draw exceeds b, otherwise use the mixed one
+            random_sample_list = torch.rand(bag_size)
+            mixup_bag = [bag[i] if random_sample_list[i] > b else temp_bag[i] for i in range(bag_size)]  # TODO: vectorize, e.g. with torch.where
+            mixup_bag = torch.stack(mixup_bag)
+            # print('else')
+            # print(mixup_bag.shape)
+
+        return mixup_bag
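+    # The mixup above forms, per instance, a convex combination of two shuffled
+    # views of the same bag: for features x_i, x_j and a ~ U(0, 1),
+    #     x_mix = a * x_i + (1 - a) * x_j
+    # Bags shorter than max_bag_size are topped up with mixed instances; longer
+    # bags become a random mix of original and mixed instances.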
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+
+        if self.cache:
+            label = self.labels[index]
+            bag = self.feature_bags[index]
             
+        
+            
+            # label = Variable(Tensor(label))
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+            wsi_name = self.wsi_names[index]
+            batch_coords = self.coords[index]
+            patient = self.patients[index]
+            out_bag = bag  # minimal fix: cached bags are returned as-is (no subsampling/mixup applied)
+
+            
+            #random dropout
+            #shuffle
+
+            # feats = Variable(Tensor(feats))
+            # return wsi, label, (wsi_name, batch_coords, patient)
+        else:
+            t = self.files[index]
+            label = self.labels[index]
+            bag, (wsi_name, batch_coords, patient) = self.get_data(t)
+            # print(bag.shape)
+            # label = torch.as_tensor(label)
+            # label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+                # self.labels.append(label)
+                # self.feature_bags.append(batch)
+                # self.wsi_names.append(wsi_name)
+                # self.name_batches.append(name_batch)
+                # self.patients.append(patient)
+            if self.mode == 'train':
+                bag_size = bag.shape[0]
+
+                bag_idxs = torch.randperm(bag_size)[:self.max_bag_size]
+                # bag_idxs = torch.randperm(bag_size)[:int(self.max_bag_size*(1-self.drop_rate))]
+                out_bag = bag[bag_idxs, :]
+                if self.mixup:
+                    out_bag = self.get_mixup_bag(out_bag)
+                    # batch_coords = 
+                if out_bag.shape[0] < self.max_bag_size:
+                    out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+
+                # shuffle again
+                out_bag_idxs = torch.randperm(out_bag.shape[0])
+                out_bag = out_bag[out_bag_idxs]
+
+
+                # batch_coords only useful for test
+                batch_coords = batch_coords[bag_idxs]
+                # out_bag = bag
+
+            # mixup? Linear combination of 2 vectors
+            # add noise
+
 
-        # mixup? Linear combination of 2 vectors
-        # add noise
+            # elif self.mode == 'val': 
+            #     bag_size = bag.shape[0]
+            #     bag_idxs = torch.randperm(bag_size)[:self.max_bag_size]
+            #     out_bag = bag[bag_idxs, :]
+            #     if out_bag.shape[0] < self.max_bag_size:
+            #         out_bag = torch.cat((out_bag, torch.zeros(self.max_bag_size-out_bag.shape[0], out_bag.shape[1])))
+            else:
+                # bag_size = bag.shape[0]
+                out_bag = bag
 
 
-        else: out_bag = bag
+        return out_bag, label, (wsi_name, patient)
 
-        return out_bag, label, (wsi_name, batch_coords, patient)
 
 if __name__ == '__main__':
     
@@ -328,23 +595,24 @@ if __name__ == '__main__':
     # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
     # label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest_test.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_5_split_PAS_HE_Jones_norm_rest_test.json'
     output_dir = f'/{data_root}/debug/augments'
     os.makedirs(output_dir, exist_ok=True)
 
     n_classes = 2
 
-    dataset = FeatureBagLoader(data_root, label_path=label_path, mode='train', cache=False, mixup=True, aug=True, n_classes=n_classes)
+    train_dataset = FeatureBagLoader(data_root, label_path=label_path, mode='train', cache=False, mixup=True, aug=True, n_classes=n_classes)
+    val_dataset = FeatureBagLoader(data_root, label_path=label_path, mode='val', cache=False, mixup=False, aug=False, n_classes=n_classes)
 
     test_dataset = FeatureBagLoader(data_root, label_path=label_path, mode='test', cache=False, n_classes=n_classes)
 
     # print(dataset.get_labels(0))
-    a = int(len(dataset)* 0.8)
-    b = int(len(dataset) - a)
-    train_data, valid_data = random_split(dataset, [a, b])
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_data, valid_data = random_split(dataset, [a, b])
 
-    train_dl = DataLoader(train_data, batch_size=1, num_workers=5)
-    valid_dl = DataLoader(valid_data, batch_size=1, num_workers=5)
+    train_dl = DataLoader(train_dataset, batch_size=1, sampler=ImbalancedDatasetSampler(train_dataset), num_workers=5)
+    valid_dl = DataLoader(val_dataset, batch_size=1, num_workers=5)
     test_dl = DataLoader(test_dataset)
 
     print('train_dl: ', len(train_dl))
@@ -371,7 +639,7 @@ if __name__ == '__main__':
     # start = time.time()
     for i in range(epochs):
         start = time.time()
-        for item in tqdm(train_dl): 
+        for item in tqdm(valid_dl): 
 
             # if c >= 10:
             #     break
diff --git a/code/datasets/jpg_dataloader.py b/code/datasets/jpg_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c983d555725aa995c290970c751ec00be4275905
--- /dev/null
+++ b/code/datasets/jpg_dataloader.py
@@ -0,0 +1,434 @@
+# import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils import data
+from torch.utils.data import random_split, DataLoader
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+import torch.utils.data as data_utils
+import torchvision.transforms as transforms
+import pandas as pd
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+from PIL import Image
+import cv2
+import json
+from imgaug import augmenters as iaa
+from torchsampler import ImbalancedDatasetSampler
+from .utils import myTransforms
+
+
+class JPGMILDataloader(data_utils.Dataset):
+    def __init__(self, file_path, label_path, mode, model, n_classes, data_cache_size=100, max_bag_size=1000, cache=False, mixup=False, aug=False):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.max_bag_size = max_bag_size
+        self.min_bag_size = 50
+        self.empty_slides = []
+        self.corrupt_slides = []
+        self.cache = False  # NOTE: the cache argument passed to __init__ is currently ignored
+        self.labels = []
+
+        # self.features = []
+        # self.labels = []
+        # self.wsi_names = []
+        # self.name_batches = []
+        # self.patients = []
+        
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            json_dict = json.load(f)
+            temp_slide_label_dict = json_dict[self.mode]
+            # print(len(temp_slide_label_dict))
+            for (x,y) in temp_slide_label_dict:
+                x = x.replace('FEATURES_RETCCL_2048', 'BLOCKS')
+                # print(x)
+                x_name = Path(x).stem
+                x_path_list = [Path(self.file_path)/x]
+                for x_path in x_path_list:
+                    if x_path.exists():
+                        # print(len(list(x_path.glob('*'))))
+                        self.slideLabelDict[x_name] = y
+                        self.labels += [int(y)]*len(list(x_path.glob('*')))
+                        # self.labels.append(int(y))
+                        self.files.append(x_path)
+
+        # with open(self.label_path, 'r') as f:
+        #     temp_slide_label_dict = json.load(f)[mode]
+        #     print(len(temp_slide_label_dict))
+        #     for (x, y) in temp_slide_label_dict:
+        #         x = Path(x).stem 
+        #         # x_complete_path = Path(self.file_path)/Path(x)
+        #         for cohort in Path(self.file_path).iterdir():
+        #             x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
+        #             if x_complete_path.is_dir():
+        #                 if len(list(x_complete_path.iterdir())) > self.min_bag_size:
+        #                 # print(x_complete_path)
+        #                     self.slideLabelDict[x] = y
+        #                     self.files.append(x_complete_path)
+        #                 else: self.empty_slides.append(x_complete_path)
+        
+        home = Path.cwd().parts[1]
+        # self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
+        self.slide_patient_dict_path = Path(self.label_path).parent / 'slide_patient_dict_an.json'
+        with open(self.slide_patient_dict_path, 'r') as f:
+            self.slide_patient_dict = json.load(f)
+
+        # def get_transforms_2():
+        
+        self.color_transforms = myTransforms.Compose([
+            myTransforms.ColorJitter(
+                brightness = (0.65, 1.35), 
+                contrast = (0.5, 1.5),
+                ),
+            myTransforms.HEDJitter(theta=0.005),
+            
+        ])
+        self.train_transforms = myTransforms.Compose([
+            myTransforms.RandomChoice([myTransforms.RandomHorizontalFlip(p=0.5),
+                                        myTransforms.RandomVerticalFlip(p=0.5),
+                                        myTransforms.AutoRandomRotation()]),
+        
+            myTransforms.RandomGaussBlur(radius=[0.5, 1.5]),
+            myTransforms.RandomAffineCV2(alpha=0.1),
+            myTransforms.RandomElastic(alpha=2, sigma=0.06),
+        ])
+
+
+
+        # sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        # sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        # sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        # sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        # sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        # self.resize_transforms = iaa.Sequential([
+        #     myTransforms.Resize(size=(299,299)),
+        # ], name='resizeAug')
+
+        # self.train_transforms = iaa.Sequential([
+        #     iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
+        #     sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+        #     iaa.Fliplr(0.5, name="MyFlipLR"),
+        #     iaa.Flipud(0.5, name="MyFlipUD"),
+        #     sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+        #     # iaa.OneOf([
+        #     #     sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+        #     #     sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+        #     #     sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+        #     # ], name="MyOneOf")
+
+        # ], name="MyAug")
+        self.val_transforms = transforms.Compose([
+            # 
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225],
+            ),
+            # RangeNormalization(),
+        ])
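+        # The normalization constants above are the standard ImageNet statistics,
+        # matching ImageNet-pretrained feature extractors (e.g. the resnet50_baseline
+        # used in __main__ below).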
+
+
+
+
+        
+        # if self.cache:
+        #     if mode=='train':
+        #         seq_img_d = self.train_transforms.to_deterministic()
+                
+        #         # with tqdm(total=len(self.files)) as pbar:
+
+        #         for t in tqdm(self.files):
+        #             batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
+        #             # print('label: ', label)
+        #             out_batch = []
+        #             for img in batch: 
+        #                 img = img.numpy().astype(np.uint8)
+        #                 img = seq_img_d.augment_image(img)
+        #                 img = self.val_transforms(img.copy())
+        #                 out_batch.append(img)
+        #             # ft = ft.view(-1, 512)
+                    
+        #             out_batch = torch.stack(out_batch)
+        #             self.labels.append(label)
+        #             self.features.append(out_batch)
+        #             self.wsi_names.append(wsi_name)
+        #             self.name_batches.append(name_batch)
+        #             self.patients.append(patient)
+        #                 # pbar.update()
+        #     else: 
+        #         # with tqdm(total=len(self.file_path)) as pbar:
+        #         for t in tqdm(self.file_path):
+        #             batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
+        #             out_batch = []
+        #             for img in batch: 
+        #                 img = img.numpy().astype(np.uint8)
+        #                 img = self.val_transforms(img.copy())
+        #                 out_batch.append(img)
+        #             # ft = ft.view(-1, 512)
+        #             out_batch = torch.stack(out_batch)
+        #             self.labels.append(label)
+        #             self.features.append(out_batch)
+        #             self.wsi_names.append(wsi_name)
+        #             self.name_batches.append(name_batch)
+        #             self.patients.append(patient)
+                        # pbar.update()
+        # print(self.get_bag_feats(self.train_path))
+        # self.r = np.random.RandomState(seed)
+
+        # self.num_in_train = 60000
+        # self.num_in_test = 10000
+
+        # if self.train:
+        #     self.train_bags_list, self.train_labels_list = self._create_bags()
+        # else:
+        #     self.test_bags_list, self.test_labels_list = self._create_bags()
+
+    def get_data(self, file_path):
+        
+        # NOTE: these local transforms duplicate self.color_transforms / self.train_transforms
+        # defined in __init__ and are not used below.
+        color_transforms = myTransforms.Compose([
+            myTransforms.ColorJitter(
+                brightness = (0.65, 1.35), 
+                contrast = (0.5, 1.5),
+                # saturation=(0, 2), 
+                # hue=0.3,
+                ),
+            # myTransforms.RandomChoice([myTransforms.ColorJitter(saturation=(0, 2), hue=0.3),
+            #                             myTransforms.HEDJitter(theta=0.05)]),
+            myTransforms.HEDJitter(theta=0.005),
+            
+        ])
+        train_transforms = myTransforms.Compose([
+            myTransforms.RandomChoice([myTransforms.RandomHorizontalFlip(p=0.5),
+                                        myTransforms.RandomVerticalFlip(p=0.5),
+                                        myTransforms.AutoRandomRotation()]),
+        
+            myTransforms.RandomGaussBlur(radius=[0.5, 1.5]),
+            myTransforms.RandomAffineCV2(alpha=0.1),
+            myTransforms.RandomElastic(alpha=2, sigma=0.06),
+        ])
+
+        wsi_batch=[]
+        name_batch=[]
+        
+        for tile_path in Path(file_path).iterdir():
+            img = Image.open(tile_path)
+            if self.mode == 'train':
+            
+                img = self.color_transforms(img)
+                img = self.train_transforms(img)
+            img = self.val_transforms(img)
+            # img = np.asarray(Image.open(tile_path)).astype(np.uint8)
+            # img = np.moveaxis(img, 2, 0)
+            # print(img.shape)
+            # img = torch.from_numpy(img)
+            wsi_batch.append(img)
+            name_batch.append(tile_path.stem)
+
+        wsi_batch = torch.stack(wsi_batch)
+
+        # if wsi_batch.size(0) > self.max_bag_size:
+        
+
+        wsi_name = Path(file_path).stem
+        try:
+            label = self.slideLabelDict[wsi_name]
+        except KeyError:
+            print(f'{wsi_name} is not included in label file {self.label_path}')
+            raise
+
+        
+
+        try:
+            patient = self.slide_patient_dict[wsi_name]
+        except KeyError:
+            print(f'{wsi_name} is not included in slide-patient dict {self.slide_patient_dict_path}')
+            raise
+
+        return wsi_batch, label, (wsi_name, name_batch, patient)
+    
+    def get_labels(self, indices):
+        return [self.labels[i] for i in indices]
+
+
+    def to_fixed_size_bag(self, bag, bag_size: int = 512):
+
+        # randomly subsample to bag_size, then zero-pad if the bag is smaller
+
+        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+        bag_samples = bag[bag_idxs]
+        # name_samples = [names[i] for i in bag_idxs]
+
+        # bag_sample_names = [bag_names[i] for i in bag_idxs]
+        # q, r  = divmod(bag_size, bag_samples.shape[0])
+        # if q > 0:
+        #     bag_samples = torch.cat([bag_samples]*q, 0)
+
+        # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+        # zero-pad if we don't have enough samples
+        zero_padded = torch.cat((bag_samples,
+                                torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+        return zero_padded, min(bag_size, len(bag))
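+    # Instances are shuffled before truncation, so the zero-padding only takes effect
+    # when a slide yields fewer than bag_size tiles; the second return value is the
+    # number of real (non-padded) instances.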
+
+    def data_dropout(self, bag, drop_rate):
+        bag_size = bag.shape[0]
+        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
+        bag_samples = bag[bag_idxs]
+        # name_samples = [batch_names[i] for i in bag_idxs]
+
+        return bag_samples
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+
+        if self.cache:
+            label = self.labels[index]
+            wsi = self.features[index]
+            label = int(label)
+            wsi_name = self.wsi_names[index]
+            name_batch = self.name_batches[index]
+            patient = self.patients[index]
+            # feats = Variable(Tensor(feats))
+            return wsi, label, (wsi_name, name_batch, patient)
+        else:
+            t = self.files[index]
+            # label = self.labels[index]
+            if self.mode=='train':
+
+                batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
+                batch, _ = self.to_fixed_size_bag(batch, self.max_bag_size)
+                batch = self.data_dropout(batch, drop_rate=0.1)
+                # print(batch.shape)
+                # # label = Variable(Tensor(label))
+
+                # # wsi = Variable(Tensor(wsi_batch))
+                # out_batch = []
+
+                # # seq_img_d = self.train_transforms.to_deterministic()
+                # # seq_img_resize = self.resize_transforms.to_deterministic()
+                # for img in batch: 
+                #     # img = img.numpy().astype(np.uint8)
+                #     # print(img.shape)
+                #     img = self.resize_transforms(img)
+                #     # print(img)
+                #     # print(img.shape)
+                #     # img = torch.moveaxis(img, 0, 2) # with HEDJitter wants [W,H,3], ColorJitter wants [3,W,H]
+                #     # print(img.shape)
+                #     img = self.color_transforms(img)
+                #     print(img.shape)
+                #     img = self.train_transforms(img)
+                    
+                #     # img = seq_img_d.augment_image(img)
+                #     img = self.val_transforms(img.copy())
+                #     out_batch.append(img)
+                # out_batch = torch.stack(out_batch)
+                out_batch = batch
+                
+                # ft = ft.view(-1, 512)
+                
+            else:
+                batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
+                batch, _ = self.to_fixed_size_bag(batch, self.max_bag_size)
+                # label = Variable(Tensor(label))
+                # out_batch = []
+                # # seq_img_d = self.train_transforms.to_deterministic()
+                # # seq_img_resize = self.resize_transforms.to_deterministic()
+                # for img in batch: 
+                #     # img = img.numpy().astype(np.uint8)
+                #     # img = seq_img_resize(images=img)
+                #     img = self.resize_transforms(img)
+                #     img = np.moveaxis(img, 0, 2)
+                #     # img = img.numpy().astype(np.uint8)
+                #     # print(img.shape)
+                #     img = self.val_transforms(img)
+                #     out_batch.append(img)
+                # out_batch = torch.stack(out_batch)
+                out_batch = batch
+
+            return out_batch, label, (wsi_name, patient)
+
+if __name__ == '__main__':
+    
+    from pathlib import Path
+    import os
+    import time
+    from fast_tensor_dl import FastTensorDataLoader
+    from custom_resnet50 import resnet50_baseline
+    
+    
+
+    home = Path.cwd().parts[1]
+    # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_256uM_annotated'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    # label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_5_split_PAS_HE_Jones_norm_rest_test.json'
+    # output_dir = f'/{data_root}/debug/augments'
+    # os.makedirs(output_dir, exist_ok=True)
+
+    n_classes = 2
+
+    # model is unused in JPGMILDataloader.__init__, so None is passed only to satisfy the signature
+    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', model=None, n_classes=n_classes, cache=False)
+
+    # print(dataset.get_labels(0))
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_data, valid_data = random_split(dataset, [a, b])
+    # print(dataset.dataset)
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
+    dl = DataLoader(dataset, batch_size=2, num_workers=8, pin_memory=True)
+    # print(len(dl))
+    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    scaler = torch.cuda.amp.GradScaler()
+
+    model_ft = resnet50_baseline(pretrained=True)
+    for param in model_ft.parameters():
+        param.requires_grad = False
+    model_ft.to(device)
+    
+    c = 0
+    label_count = [0] *n_classes
+    # print(len(dl))
+    start = time.time()
+    for item in tqdm(dl): 
+
+        # if c >= 10:
+        #     break
+        bag, label, (name, patient) = item
+        print(bag.shape)
+        # print(name)
+        # print(batch_names)
+        # print(patient)
+        # print(len(batch_names))
+
+        print(label.shape)
+        # bag = bag.squeeze(0).float().to(device)
+        # label = label.to(device)
+        # with torch.cuda.amp.autocast():
+        #     output = model_ft(bag)
+        c += 1
+    end = time.time()
+
+    print('Bag Time: ', end-start)
\ No newline at end of file
diff --git a/code/datasets/monai_loader.py b/code/datasets/monai_loader.py
index 0bbe8033e7e15c82f1ffcafce7feddc7cc1ce770..4a3240e05a25b8cd6da3d19a097ee8cb740d7ed2 100644
--- a/code/datasets/monai_loader.py
+++ b/code/datasets/monai_loader.py
@@ -111,10 +111,11 @@ if __name__ == '__main__':
         data_list_key="training",
         base_dir=data_root,
     )
+    
 
     train_transform = Compose(
         [
-            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=0, image_only=True, num_workers=8),
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True, num_workers=8),
             LabelEncodeIntegerGraded(keys=["label"], num_classes=num_classes),
             RandGridPatchd(
                 keys=["image"],
@@ -133,7 +134,7 @@ if __name__ == '__main__':
             ToTensord(keys=["image", "label"]),
         ]
     )
-    train_data_list = data['training']
+    training_list = data['training']
     # dataset_train = Dataset(data=training_list)
     dataset_train = Dataset(data=training_list, transform=train_transform)
     # persistent_dataset = PersistentDataset(data=training_list, transform=train_transform, cache_dir='/home/ylan/workspace/test')
diff --git a/code/datasets/simple_jpg_dataloader.py b/code/datasets/simple_jpg_dataloader.py
deleted file mode 100644
index c5e349f3b7f6426140fe8c9b026d3dc8932853bf..0000000000000000000000000000000000000000
--- a/code/datasets/simple_jpg_dataloader.py
+++ /dev/null
@@ -1,323 +0,0 @@
-# import pandas as pd
-
-import numpy as np
-import torch
-from torch import Tensor
-from torch.utils import data
-from torch.utils.data import random_split, DataLoader
-from torch.autograd import Variable
-from torch.nn.functional import one_hot
-import torch.utils.data as data_utils
-import torchvision.transforms as transforms
-import pandas as pd
-from sklearn.utils import shuffle
-from pathlib import Path
-from tqdm import tqdm
-from PIL import Image
-import cv2
-import json
-from imgaug import augmenters as iaa
-from torchsampler import ImbalancedDatasetSampler
-
-
-class JPGBagLoader(data_utils.Dataset):
-    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=100, max_bag_size=1000, cache=False):
-        super().__init__()
-
-        self.data_info = []
-        self.data_cache = {}
-        self.slideLabelDict = {}
-        self.files = []
-        self.data_cache_size = data_cache_size
-        self.mode = mode
-        self.file_path = file_path
-        # self.csv_path = csv_path
-        self.label_path = label_path
-        self.n_classes = n_classes
-        self.max_bag_size = max_bag_size
-        self.min_bag_size = 50
-        self.empty_slides = []
-        self.corrupt_slides = []
-        self.cache = True
-        
-        # read labels and slide_path from csv
-        with open(self.label_path, 'r') as f:
-            temp_slide_label_dict = json.load(f)[mode]
-            print(len(temp_slide_label_dict))
-            for (x, y) in temp_slide_label_dict:
-                x = Path(x).stem 
-                # x_complete_path = Path(self.file_path)/Path(x)
-                for cohort in Path(self.file_path).iterdir():
-                    x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
-                    if x_complete_path.is_dir():
-                        if len(list(x_complete_path.iterdir())) > self.min_bag_size:
-                        # print(x_complete_path)
-                            self.slideLabelDict[x] = y
-                            self.files.append(x_complete_path)
-                        else: self.empty_slides.append(x_complete_path)
-        
-        home = Path.cwd().parts[1]
-        self.slide_patient_dict_path = f'/{home}/ylan/DeepGraft/training_tables/slide_patient_dict.json'
-        with open(self.slide_patient_dict_path, 'r') as f:
-            self.slide_patient_dict = json.load(f)
-
-        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
-        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
-        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
-        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
-        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
-
-        self.train_transforms = iaa.Sequential([
-            iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
-            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
-            iaa.Fliplr(0.5, name="MyFlipLR"),
-            iaa.Flipud(0.5, name="MyFlipUD"),
-            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
-            # iaa.OneOf([
-            #     sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
-            #     sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
-            #     sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
-            # ], name="MyOneOf")
-
-        ], name="MyAug")
-        self.val_transforms = transforms.Compose([
-            # 
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=[0.485, 0.456, 0.406],
-                std=[0.229, 0.224, 0.225],
-            ),
-            # RangeNormalization(),
-        ])
-
-
-
-
-        self.features = []
-        self.labels = []
-        self.wsi_names = []
-        self.name_batches = []
-        self.patients = []
-        if self.cache:
-            if mode=='train':
-                seq_img_d = self.train_transforms.to_deterministic()
-                
-                # with tqdm(total=len(self.files)) as pbar:
-
-                for t in tqdm(self.files):
-                    batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
-                    # print('label: ', label)
-                    out_batch = []
-                    for img in batch: 
-                        img = img.numpy().astype(np.uint8)
-                        img = seq_img_d.augment_image(img)
-                        img = self.val_transforms(img.copy())
-                        out_batch.append(img)
-                    # ft = ft.view(-1, 512)
-                    
-                    out_batch = torch.stack(out_batch)
-                    self.labels.append(label)
-                    self.features.append(out_batch)
-                    self.wsi_names.append(wsi_name)
-                    self.name_batches.append(name_batch)
-                    self.patients.append(patient)
-                        # pbar.update()
-            else: 
-                # with tqdm(total=len(self.file_path)) as pbar:
-                for t in tqdm(self.file_path):
-                    batch, label, (wsi_name, name_batch, patient) = self.get_data(t)
-                    out_batch = []
-                    for img in batch: 
-                        img = img.numpy().astype(np.uint8)
-                        img = self.val_transforms(img.copy())
-                        out_batch.append(img)
-                    # ft = ft.view(-1, 512)
-                    out_batch = torch.stack(out_batch)
-                    self.labels.append(label)
-                    self.features.append(out_batch)
-                    self.wsi_names.append(wsi_name)
-                    self.name_batches.append(name_batch)
-                    self.patients.append(patient)
-                        # pbar.update()
-        # print(self.get_bag_feats(self.train_path))
-        # self.r = np.random.RandomState(seed)
-
-        # self.num_in_train = 60000
-        # self.num_in_test = 10000
-
-        # if self.train:
-        #     self.train_bags_list, self.train_labels_list = self._create_bags()
-        # else:
-        #     self.test_bags_list, self.test_labels_list = self._create_bags()
-
-    def get_data(self, file_path):
-        
-        wsi_batch=[]
-        name_batch=[]
-        
-        for tile_path in Path(file_path).iterdir():
-            img = np.asarray(Image.open(tile_path)).astype(np.uint8)
-            img = torch.from_numpy(img)
-            wsi_batch.append(img)
-            name_batch.append(tile_path.stem)
-
-        wsi_batch = torch.stack(wsi_batch)
-
-        if wsi_batch.size(0) > self.max_bag_size:
-            wsi_batch, name_batch, _ = self.to_fixed_size_bag(wsi_batch, name_batch, self.max_bag_size)
-
-
-        wsi_batch, name_batch = self.data_dropout(wsi_batch, name_batch, drop_rate=0.1)
-
-        wsi_name = Path(file_path).stem
-        try:
-            label = self.slideLabelDict[wsi_name]
-        except KeyError:
-            print(f'{wsi_name} is not included in label file {self.label_path}')
-
-        try:
-            patient = self.slide_patient_dict[wsi_name]
-        except KeyError:
-            print(f'{wsi_name} is not included in label file {self.slide_patient_dict_path}')
-
-        return wsi_batch, label, (wsi_name, name_batch, patient)
-    
-    def get_labels(self, indices):
-        return [self.labels[i] for i in indices]
-
-
-    def to_fixed_size_bag(self, bag, names, bag_size: int = 512):
-
-        #duplicate bag instances unitl 
-
-        bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
-        bag_samples = bag[bag_idxs]
-        name_samples = [names[i] for i in bag_idxs]
-        # bag_sample_names = [bag_names[i] for i in bag_idxs]
-        # q, r  = divmod(bag_size, bag_samples.shape[0])
-        # if q > 0:
-        #     bag_samples = torch.cat([bag_samples]*q, 0)
-
-        # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
-
-        # zero-pad if we don't have enough samples
-        # zero_padded = torch.cat((bag_samples,
-        #                         torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
-
-        return bag_samples, name_samples, min(bag_size, len(bag))
-
-    def data_dropout(self, bag, batch_names, drop_rate):
-        bag_size = bag.shape[0]
-        bag_idxs = torch.randperm(bag_size)[:int(bag_size*(1-drop_rate))]
-        bag_samples = bag[bag_idxs]
-        name_samples = [batch_names[i] for i in bag_idxs]
-
-        return bag_samples, name_samples
-
-    def __len__(self):
-        return len(self.files)
-
-    def __getitem__(self, index):
-
-        if self.cache:
-            label = self.labels[index]
-            wsi = self.features[index]
-            label = int(label)
-            wsi_name = self.wsi_names[index]
-            name_batch = self.name_batches[index]
-            patient = self.patients[index]
-            # feats = Variable(Tensor(feats))
-            return wsi, label, (wsi_name, name_batch, patient)
-        else:
-            if self.mode=='train':
-                batch, label, (wsi_name, name_batch, patient) = self.get_data(self.files[index])
-                # label = Variable(Tensor(label))
-
-                # wsi = Variable(Tensor(wsi_batch))
-                out_batch = []
-                seq_img_d = self.train_transforms.to_deterministic()
-                for img in batch: 
-                    img = img.numpy().astype(np.uint8)
-                    # img = seq_img_d.augment_image(img)
-                    img = self.val_transforms(img.copy())
-                    out_batch.append(img)
-                out_batch = torch.stack(out_batch)
-                # ft = ft.view(-1, 512)
-                
-            else:
-                batch, label, (wsi_name, name_batch, patient) = self.get_data(self.files[index])
-                label = Variable(Tensor(label))
-                out_batch = []
-                seq_img_d = self.train_transforms.to_deterministic()
-                for img in batch: 
-                    img = img.numpy().astype(np.uint8)
-                    img = self.val_transforms(img.copy())
-                    out_batch.append(img)
-                out_batch = torch.stack(out_batch)
-
-            return out_batch, label, (wsi_name, name_batch, patient)
-
-if __name__ == '__main__':
-    
-    from pathlib import Path
-    import os
-    import time
-    from fast_tensor_dl import FastTensorDataLoader
-    from custom_resnet50 import resnet50_baseline
-    
-    
-
-    home = Path.cwd().parts[1]
-    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
-    data_root = f'/{home}/ylan/data/DeepGraft/224_128um_v2'
-    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
-    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_debug.json'
-    # label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
-    output_dir = f'/{data_root}/debug/augments'
-    os.makedirs(output_dir, exist_ok=True)
-
-    n_classes = 2
-
-    dataset = JPGBagLoader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
-
-    # print(dataset.get_labels(0))
-    a = int(len(dataset)* 0.8)
-    b = int(len(dataset) - a)
-    train_data, valid_data = random_split(dataset, [a, b])
-    # print(dataset.dataset)
-    # a = int(len(dataset)* 0.8)
-    # b = int(len(dataset) - a)
-    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
-    # dl = FastTensorDataLoader(dataset, batch_size=1, shuffle=False)
-    dl = DataLoader(train_data, batch_size=1, num_workers=8, sampler=ImbalancedDatasetSampler(train_data), pin_memory=True)
-    # print(len(dl))
-    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
-    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
-    scaler = torch.cuda.amp.GradScaler()
-
-    model_ft = resnet50_baseline(pretrained=True)
-    for param in model_ft.parameters():
-        param.requires_grad = False
-    model_ft.to(device)
-    
-    c = 0
-    label_count = [0] *n_classes
-    # print(len(dl))
-    start = time.time()
-    for item in tqdm(dl): 
-
-        # if c >= 10:
-        #     break
-        bag, label, (name, batch_names, patient) = item
-        # print(bag.shape)
-        # print(len(batch_names))
-        print(label)
-        bag = bag.squeeze(0).float().to(device)
-        label = label.to(device)
-        with torch.cuda.amp.autocast():
-            output = model_ft(bag)
-        c += 1
-    end = time.time()
-
-    print('Bag Time: ', end-start)
\ No newline at end of file
diff --git a/code/datasets/utils/__init__.py b/code/datasets/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/code/datasets/utils/__pycache__/__init__.cpython-39.pyc b/code/datasets/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19391f5fd1eeca742d1e73b5108a9cf52b2c4e60
Binary files /dev/null and b/code/datasets/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/datasets/utils/__pycache__/myTransforms.cpython-39.pyc b/code/datasets/utils/__pycache__/myTransforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1a33858df412d4143659f5a98c36a3b322b6180
Binary files /dev/null and b/code/datasets/utils/__pycache__/myTransforms.cpython-39.pyc differ
diff --git a/code/datasets/utils/myTransforms.py b/code/datasets/utils/myTransforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ac760a570155638a2c3f6b7689be29eceb8d25b
--- /dev/null
+++ b/code/datasets/utils/myTransforms.py
@@ -0,0 +1,1426 @@
+from __future__ import division
+import torch
+import math
+import sys
+import random
+try:
+    import accimage
+except ImportError:
+    accimage = None
+import numpy as np
+import numbers
+import types
+import collections
+import warnings
+
+from PIL import Image, ImageFilter
+from skimage import color
+import cv2
+from scipy.ndimage import map_coordinates, gaussian_filter
+from torchvision.transforms import functional as F
+import torchvision.transforms as T
+
+if sys.version_info < (3, 3):
+    Sequence = collections.Sequence
+    Iterable = collections.Iterable
+else:
+    Sequence = collections.abc.Sequence
+    Iterable = collections.abc.Iterable
+
+
+# 2020.4.19 Chaoyang
+# add new augmentation class *HEDJitter* for HED color space perturbation.
+# add new random rotation class *AutoRandomRotation* only for 0, 90,180,270 rotation.
+# delete Scale class because it is inapplicable now.
+# 2020.4.20 Chaoyang
+# add annotation for class *RandomAffine*, how to use it. line 1040 -- 1046
+# add new augmentation class *RandomGaussBlur* for gaussian blurring.
+# add new augmentation class *RandomAffineCV2* for affine transformation by cv2, which can \
+# set BORDER_REFLECT for the area outside the transform in the output image.
+# add new augmentation class *RandomElastic* for elastic transformation by cv2
+__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad",
+           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
+           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
+           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
+           "RandomPerspective", "RandomErasing",
+           "HEDJitter", "AutoRandomRotation", "RandomGaussBlur", "RandomAffineCV2", "RandomElastic"]
+
+_pil_interpolation_to_str = {
+    Image.NEAREST: 'T.InterpolationMode.NEAREST',
+    Image.BILINEAR: 'T.InterpolationMode.BILINEAR',
+    Image.BICUBIC: 'T.InterpolationMode.BICUBIC',
+    Image.LANCZOS: 'T.InterpolationMode.LANCZOS',
+    Image.HAMMING: 'T.InterpolationMode.HAMMING',
+    Image.BOX: 'T.InterpolationMode.BOX',
+}
+# _pil_interpolation_to_str = {
+#     Image.NEAREST: 'PIL.Image.NEAREST',
+#     Image.BILINEAR: 'PIL.Image.BILINEAR',
+#     Image.BICUBIC: 'PIL.Image.BICUBIC',
+#     Image.LANCZOS: 'PIL.Image.LANCZOS',
+#     Image.HAMMING: 'PIL.Image.HAMMING',
+#     Image.BOX: 'PIL.Image.BOX',
+# }
+
+
+def _get_image_size(img):
+    if F._is_pil_image(img):
+        return img.size
+    elif isinstance(img, torch.Tensor) and img.dim() > 2:
+        return img.shape[-2:][::-1]
+    else:
+        raise TypeError("Unexpected type {}".format(type(img)))
+
+
+class Compose(object):
+    """Composes several transforms together.
+    Args:
+        transforms (list of ``Transform`` objects): list of transforms to compose.
+    Example:
+        # >>> transforms.Compose([
+        # >>>     transforms.CenterCrop(10),
+        # >>>     transforms.ToTensor(),
+        # >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img):
+        for t in self.transforms:
+            img = t(img)
+        return img
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string
+
+
+class ToTensor(object):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
+    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
+    or if the numpy.ndarray has dtype = np.uint8
+    In the other cases, tensors are returned without scaling.
+    """
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+        Returns:
+            Tensor: Converted image.
+        """
+        return F.to_tensor(pic)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+
+class ToPILImage(object):
+    """Convert a tensor or an ndarray to PIL Image.
+    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
+    H x W x C to a PIL Image while preserving the value range.
+    Args:
+        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
+            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
+             - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
+             - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
+             - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
+             - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
+               ``short``).
+    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
+    """
+    def __init__(self, mode=None):
+        self.mode = mode
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
+        Returns:
+            PIL Image: Image converted to PIL Image.
+        """
+        return F.to_pil_image(pic, self.mode)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        if self.mode is not None:
+            format_string += 'mode={0}'.format(self.mode)
+        format_string += ')'
+        return format_string
+
+
+class Normalize(object):
+    """Normalize a tensor image with mean and standard deviation.
+    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
+    will normalize each channel of the input ``torch.*Tensor`` i.e.
+    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+    .. note::
+        This transform acts out of place, i.e., it does not mutate the input tensor.
+    Args:
+        mean (sequence): Sequence of means for each channel.
+        std (sequence): Sequence of standard deviations for each channel.
+        inplace(bool,optional): Bool to make this operation in-place.
+    """
+
+    def __init__(self, mean, std, inplace=False):
+        self.mean = mean
+        self.std = std
+        self.inplace = inplace
+
+    def __call__(self, tensor):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+        Returns:
+            Tensor: Normalized Tensor image.
+        """
+        return F.normalize(tensor, self.mean, self.std, self.inplace)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
+
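+# Minimal usage sketch (assuming an RGB tensor in [0, 1] and ImageNet statistics):
+#     norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+#     x = torch.rand(3, 224, 224)          # (C, H, W)
+#     x_norm = norm(x)                     # per-channel (x - mean[c]) / std[c]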
+
+class Resize(object):
+    """Resize the input PIL Image to the given size.
+    Args:
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), output size will be matched to this. If size is an int,
+            smaller edge of the image will be matched to this number.
+            i.e, if height > width, then image will be rescaled to
+            (size * height / width, size)
+        interpolation (int, optional): Desired interpolation. Default is
+            ``PIL.Image.BILINEAR``
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be scaled.
+        Returns:
+            PIL Image: Rescaled image.
+        """
+        return F.resize(img, self.size, self.interpolation)
+
+    def __repr__(self):
+        interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
+
+
+class CenterCrop(object):
+    """Crops the given PIL Image at the center.
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped.
+        Returns:
+            PIL Image: Cropped image.
+        """
+        return F.center_crop(img, self.size)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+class Pad(object):
+    """Pad the given PIL Image on all sides with the given "pad" value.
+    Args:
+        padding (int or tuple): Padding on each border. If a single int is provided this
+            is used to pad all borders. If tuple of length 2 is provided this is the padding
+            on left/right and top/bottom respectively. If a tuple of length 4 is provided
+            this is the padding for the left, top, right and bottom borders
+            respectively.
+        fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
+            length 3, it is used to fill R, G, B channels respectively.
+            This value is only used when the padding_mode is constant
+        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
+            Default is constant.
+            - constant: pads with a constant value, this value is specified with fill
+            - edge: pads with the last value at the edge of the image
+            - reflect: pads with reflection of image without repeating the last value on the edge
+                For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+                will result in [3, 2, 1, 2, 3, 4, 3, 2]
+            - symmetric: pads with reflection of image repeating the last value on the edge
+                For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+                will result in [2, 1, 1, 2, 3, 4, 4, 3]
+    """
+
+    def __init__(self, padding, fill=0, padding_mode='constant'):
+        assert isinstance(padding, (numbers.Number, tuple))
+        assert isinstance(fill, (numbers.Number, str, tuple))
+        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+        if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
+            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
+                             "{} element tuple".format(len(padding)))
+
+        self.padding = padding
+        self.fill = fill
+        self.padding_mode = padding_mode
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be padded.
+        Returns:
+            PIL Image: Padded image.
+        """
+        return F.pad(img, self.padding, self.fill, self.padding_mode)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
+            format(self.padding, self.fill, self.padding_mode)
+
+
+class Lambda(object):
+    """Apply a user-defined lambda as a transform.
+    Args:
+        lambd (function): Lambda/function to be used for transform.
+    """
+
+    def __init__(self, lambd):
+        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
+        self.lambd = lambd
+
+    def __call__(self, img):
+        return self.lambd(img)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+
+class RandomTransforms(object):
+    """Base class for a list of transformations with randomness
+    Args:
+        transforms (list or tuple): list of transformations
+    """
+
+    def __init__(self, transforms):
+        assert isinstance(transforms, (list, tuple))
+        self.transforms = transforms
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError()
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string
+
+
+class RandomApply(RandomTransforms):
+    """Apply randomly a list of transformations with a given probability
+    Args:
+        transforms (list or tuple): list of transformations
+        p (float): probability
+    """
+
+    def __init__(self, transforms, p=0.5):
+        super(RandomApply, self).__init__(transforms)
+        self.p = p
+
+    def __call__(self, img):
+        if self.p < random.random():
+            return img
+        for t in self.transforms:
+            img = t(img)
+        return img
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        format_string += '\n    p={}'.format(self.p)
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string
+
+
+class RandomOrder(RandomTransforms):
+    """Apply a list of transformations in a random order
+    """
+    def __call__(self, img):
+        order = list(range(len(self.transforms)))
+        random.shuffle(order)
+        for i in order:
+            img = self.transforms[i](img)
+        return img
+
+
+class RandomChoice(RandomTransforms):
+    """Apply single transformation randomly picked from a list
+    """
+    def __call__(self, img):
+        t = random.choice(self.transforms)
+        return t(img)
+
+
+class RandomCrop(object):
+    """Crop the given PIL Image at a random location.
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        padding (int or sequence, optional): Optional padding on each border
+            of the image. Default is None, i.e no padding. If a sequence of length
+            4 is provided, it is used to pad left, top, right, bottom borders
+            respectively. If a sequence of length 2 is provided, it is used to
+            pad left/right, top/bottom borders, respectively.
+        pad_if_needed (boolean): It will pad the image if smaller than the
+            desired size to avoid raising an exception. Since cropping is done
+            after padding, the padding seems to be done at a random offset.
+        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
+            length 3, it is used to fill R, G, B channels respectively.
+            This value is only used when the padding_mode is constant
+        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
+             - constant: pads with a constant value, this value is specified with fill
+             - edge: pads with the last value on the edge of the image
+             - reflect: pads with reflection of image (without repeating the last value on the edge)
+                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+                will result in [3, 2, 1, 2, 3, 4, 3, 2]
+             - symmetric: pads with reflection of image (repeating the last value on the edge)
+                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+                will result in [2, 1, 1, 2, 3, 4, 4, 3]
+    """
+
+    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+        self.padding = padding
+        self.pad_if_needed = pad_if_needed
+        self.fill = fill
+        self.padding_mode = padding_mode
+
+    @staticmethod
+    def get_params(img, output_size):
+        """Get parameters for ``crop`` for a random crop.
+        Args:
+            img (PIL Image): Image to be cropped.
+            output_size (tuple): Expected output size of the crop.
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+        """
+        w, h = _get_image_size(img)
+        th, tw = output_size
+        if w == tw and h == th:
+            return 0, 0, h, w
+
+        i = random.randint(0, h - th)
+        j = random.randint(0, w - tw)
+        return i, j, th, tw
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped.
+        Returns:
+            PIL Image: Cropped image.
+        """
+        if self.padding is not None:
+            img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+        # pad the width if needed
+        if self.pad_if_needed and img.size[0] < self.size[1]:
+            img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
+        # pad the height if needed
+        if self.pad_if_needed and img.size[1] < self.size[0]:
+            img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
+
+        i, j, h, w = self.get_params(img, self.size)
+
+        return F.crop(img, i, j, h, w)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
+
+
+class RandomHorizontalFlip(object):
+    """Horizontally flip the given PIL Image randomly with a given probability.
+    Args:
+        p (float): probability of the image being flipped. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be flipped.
+        Returns:
+            PIL Image: Randomly flipped image.
+        """
+        if random.random() < self.p:
+            return F.hflip(img)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomVerticalFlip(object):
+    """Vertically flip the given PIL Image randomly with a given probability.
+    Args:
+        p (float): probability of the image being flipped. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be flipped.
+        Returns:
+            PIL Image: Randomly flipped image.
+        """
+        if random.random() < self.p:
+            return F.vflip(img)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomPerspective(object):
+    """Performs Perspective transformation of the given PIL Image randomly with a given probability.
+    Args:
+        interpolation : Default- Image.BICUBIC
+        p (float): probability of the image being perspectively transformed. Default value is 0.5
+        distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.
+    """
+
+    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
+        self.p = p
+        self.interpolation = interpolation
+        self.distortion_scale = distortion_scale
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be Perspectively transformed.
+        Returns:
+            PIL Image: Randomly perspective-transformed image.
+        """
+        if not F._is_pil_image(img):
+            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+        if random.random() < self.p:
+            width, height = img.size
+            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
+            return F.perspective(img, startpoints, endpoints, self.interpolation)
+        return img
+
+    @staticmethod
+    def get_params(width, height, distortion_scale):
+        """Get parameters for ``perspective`` for a random perspective transform.
+        Args:
+            width : width of the image.
+            height : height of the image.
+        Returns:
+            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
+            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
+        """
+        half_height = int(height / 2)
+        half_width = int(width / 2)
+        topleft = (random.randint(0, int(distortion_scale * half_width)),
+                   random.randint(0, int(distortion_scale * half_height)))
+        topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
+                    random.randint(0, int(distortion_scale * half_height)))
+        botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
+                    random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
+        botleft = (random.randint(0, int(distortion_scale * half_width)),
+                   random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
+        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
+        endpoints = [topleft, topright, botright, botleft]
+        return startpoints, endpoints
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomResizedCrop(object):
+    """Crop the given PIL Image to random size and aspect ratio.
+    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+    is finally resized to given size.
+    This is popularly used to train the Inception networks.
+    Args:
+        size: expected output size of each edge
+        scale: range of size of the origin size cropped
+        ratio: range of aspect ratio of the origin aspect ratio cropped
+        interpolation: Default: PIL.Image.BILINEAR
+    """
+
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
+        if isinstance(size, tuple):
+            self.size = size
+        else:
+            self.size = (size, size)
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            warnings.warn("range should be of kind (min, max)")
+
+        self.interpolation = interpolation
+        self.scale = scale
+        self.ratio = ratio
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        """Get parameters for ``crop`` for a random sized crop.
+        Args:
+            img (PIL Image): Image to be cropped.
+            scale (tuple): range of size of the origin size cropped
+            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                sized crop.
+        """
+        width, height = _get_image_size(img)
+        area = height * width
+
+        for attempt in range(10):
+            target_area = random.uniform(*scale) * area
+            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+            aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if 0 < w <= width and 0 < h <= height:
+                i = random.randint(0, height - h)
+                j = random.randint(0, width - w)
+                return i, j, h, w
+
+        # Fallback to central crop
+        in_ratio = float(width) / float(height)
+        if (in_ratio < min(ratio)):
+            w = width
+            h = int(round(w / min(ratio)))
+        elif (in_ratio > max(ratio)):
+            h = height
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = width
+            h = height
+        i = (height - h) // 2
+        j = (width - w) // 2
+        return i, j, h, w
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped and resized.
+        Returns:
+            PIL Image: Randomly cropped and resized image.
+        """
+        i, j, h, w = self.get_params(img, self.scale, self.ratio)
+        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
+
+    def __repr__(self):
+        interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+        format_string += ', interpolation={0})'.format(interpolate_str)
+        return format_string
+
+
+class RandomSizedCrop(RandomResizedCrop):
+    """
+    Note: This transform is deprecated in favor of RandomResizedCrop.
+    """
+    def __init__(self, *args, **kwargs):
+        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
+                      "please use transforms.RandomResizedCrop instead.")
+        super(RandomSizedCrop, self).__init__(*args, **kwargs)
+
+
+class FiveCrop(object):
+    """Crop the given PIL Image into four corners and the central crop
+    .. Note::
+         This transform returns a tuple of images and there may be a mismatch in the number of
+         inputs and targets your Dataset returns. See below for an example of how to deal with
+         this.
+    Args:
+         size (sequence or int): Desired output size of the crop. If size is an ``int``
+            instead of sequence like (h, w), a square crop of size (size, size) is made.
+    Example:
+         # >>> transform = Compose([
+         # >>>    FiveCrop(size), # this is a list of PIL Images
+         # >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+         # >>> ])
+         # >>> #In your test loop you can do the following:
+         # >>> input, target = batch # input is a 5d tensor, target is 2d
+         # >>> bs, ncrops, c, h, w = input.size()
+         # >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
+         # >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
+    """
+
+    def __init__(self, size):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+
+    def __call__(self, img):
+        return F.five_crop(img, self.size)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+class TenCrop(object):
+    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
+    these (horizontal flipping is used by default)
+    .. Note::
+         This transform returns a tuple of images and there may be a mismatch in the number of
+         inputs and targets your Dataset returns. See below for an example of how to deal with
+         this.
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal
+    Example:
+         # >>> transform = Compose([
+         # >>>    TenCrop(size), # this is a list of PIL Images
+         # >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+         # >>> ])
+         # >>> #In your test loop you can do the following:
+         # >>> input, target = batch # input is a 5d tensor, target is 2d
+         # >>> bs, ncrops, c, h, w = input.size()
+         # >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
+         # >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
+    """
+
+    def __init__(self, size, vertical_flip=False):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+        self.vertical_flip = vertical_flip
+
+    def __call__(self, img):
+        return F.ten_crop(img, self.size, self.vertical_flip)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
+
+
+class LinearTransformation(object):
+    """Transform a tensor image with a square transformation matrix and a mean_vector computed
+    offline.
+    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
+    subtract mean_vector from it which is then followed by computing the dot
+    product with the transformation matrix and then reshaping the tensor to its
+    original shape.
+    Applications:
+        whitening transformation: Suppose X is a column vector zero-centered data.
+        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
+        perform SVD on this matrix and pass it as transformation_matrix.
+    Args:
+        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
+        mean_vector (Tensor): tensor [D], D = C x H x W
+    """
+
+    def __init__(self, transformation_matrix, mean_vector):
+        if transformation_matrix.size(0) != transformation_matrix.size(1):
+            raise ValueError("transformation_matrix should be square. Got " +
+                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
+
+        if mean_vector.size(0) != transformation_matrix.size(0):
+            raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
+                             " as any one of the dimensions of the transformation_matrix [{} x {}]"
+                             .format(*transformation_matrix.size()))
+
+        self.transformation_matrix = transformation_matrix
+        self.mean_vector = mean_vector
+
+    def __call__(self, tensor):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
+        Returns:
+            Tensor: Transformed image.
+        """
+        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
+            raise ValueError("tensor and transformation matrix have incompatible shape." +
+                             "[{} x {} x {}] != ".format(*tensor.size()) +
+                             "{}".format(self.transformation_matrix.size(0)))
+        flat_tensor = tensor.view(1, -1) - self.mean_vector
+        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
+        tensor = transformed_tensor.view(tensor.size())
+        return tensor
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '(transformation_matrix='
+        format_string += (str(self.transformation_matrix.tolist()) + ')')
+        format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')')
+        return format_string
+
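+# A minimal sketch of how the transformation_matrix and mean_vector for
+# LinearTransformation above could be computed offline (ZCA whitening); `images`
+# here is an assumed list of (C, H, W) tensors, not something defined in this module.
+# >>> X = torch.stack([img.reshape(-1) for img in images]).float()   # [N, D], D = C*H*W
+# >>> mean_vector = X.mean(dim=0)                                    # [D]
+# >>> X_centered = X - mean_vector
+# >>> cov = torch.mm(X_centered.t(), X_centered) / (X_centered.size(0) - 1)
+# >>> U, S, _ = torch.svd(cov)                                       # SVD of the covariance matrix
+# >>> transformation_matrix = U @ torch.diag(1.0 / torch.sqrt(S + 1e-5)) @ U.t()
+# >>> whitening = LinearTransformation(transformation_matrix, mean_vector)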
+
+class ColorJitter(object):
+    """Randomly change the brightness, contrast and saturation of an image.
+    Args:
+        brightness (float or tuple of float (min, max)): How much to jitter brightness.
+            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+            or the given [min, max]. Should be non negative numbers.
+        contrast (float or tuple of float (min, max)): How much to jitter contrast.
+            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+            or the given [min, max]. Should be non negative numbers.
+        saturation (float or tuple of float (min, max)): How much to jitter saturation.
+            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+            or the given [min, max]. Should be non negative numbers.
+        hue (float or tuple of float (min, max)): How much to jitter hue.
+            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+    """
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+        self.brightness = self._check_input(brightness, 'brightness')
+        self.contrast = self._check_input(contrast, 'contrast')
+        self.saturation = self._check_input(saturation, 'saturation')
+        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
+                                     clip_first_on_zero=False)
+
+    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
+        if isinstance(value, numbers.Number):
+            if value < 0:
+                raise ValueError("If {} is a single number, it must be non negative.".format(name))
+            value = [center - value, center + value]
+            if clip_first_on_zero:
+                value[0] = max(value[0], 0)
+        elif isinstance(value, (tuple, list)) and len(value) == 2:
+            if not bound[0] <= value[0] <= value[1] <= bound[1]:
+                raise ValueError("{} values should be between {}".format(name, bound))
+        else:
+            raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
+
+        # if value is 0 or (1., 1.) for brightness/contrast/saturation
+        # or (0., 0.) for hue, do nothing
+        if value[0] == value[1] == center:
+            value = None
+        return value
+
+    @staticmethod
+    def get_params(brightness, contrast, saturation, hue):
+        """Get a randomized transform to be applied on image.
+        Arguments are same as that of __init__.
+        Returns:
+            Transform which randomly adjusts brightness, contrast and
+            saturation in a random order.
+        """
+        transforms = []
+
+        if brightness is not None:
+            brightness_factor = random.uniform(brightness[0], brightness[1])
+            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
+
+        if contrast is not None:
+            contrast_factor = random.uniform(contrast[0], contrast[1])
+            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
+
+        if saturation is not None:
+            saturation_factor = random.uniform(saturation[0], saturation[1])
+            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
+
+        if hue is not None:
+            hue_factor = random.uniform(hue[0], hue[1])
+            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
+
+        random.shuffle(transforms)
+        transform = Compose(transforms)
+
+        return transform
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Input image.
+        Returns:
+            PIL Image: Color jittered image.
+        """
+        transform = self.get_params(self.brightness, self.contrast,
+                                    self.saturation, self.hue)
+        return transform(img)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        format_string += 'brightness={0}'.format(self.brightness)
+        format_string += ', contrast={0}'.format(self.contrast)
+        format_string += ', saturation={0}'.format(self.saturation)
+        format_string += ', hue={0})'.format(self.hue)
+        return format_string
+
+
+class RandomRotation(object):
+    """Rotate the image by angle.
+    Args:
+        degrees (sequence or float or int): Range of degrees to select from.
+            If degrees is a number instead of sequence like (min, max), the range of degrees
+            will be (-degrees, +degrees).
+        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+            An optional resampling filter. See `filters`_ for more information.
+            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        center (2-tuple, optional): Optional center of rotation.
+            Origin is the upper left corner.
+            Default is the center of the image.
+        fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
+            If int, it is used for all channels respectively.
+    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+    """
+
+    def __init__(self, degrees, resample=False, expand=False, center=None, fill=0):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            self.degrees = (-degrees, degrees)
+        else:
+            if len(degrees) != 2:
+                raise ValueError("If degrees is a sequence, it must be of len 2.")
+            self.degrees = degrees
+
+        self.resample = resample
+        self.expand = expand
+        self.center = center
+        self.fill = fill
+
+    @staticmethod
+    def get_params(degrees):
+        """Get parameters for ``rotate`` for a random rotation.
+        Returns:
+            sequence: params to be passed to ``rotate`` for random rotation.
+        """
+        angle = random.uniform(degrees[0], degrees[1])
+
+        return angle
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be rotated.
+        Returns:
+            PIL Image: Rotated image.
+        """
+
+        angle = self.get_params(self.degrees)
+
+        return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
+        format_string += ', resample={0}'.format(self.resample)
+        format_string += ', expand={0}'.format(self.expand)
+        if self.center is not None:
+            format_string += ', center={0}'.format(self.center)
+        format_string += ')'
+        return format_string
+
+
+class RandomAffine(object):
+    """Random affine transformation of the image keeping center invariant
+    Args:
+        degrees (sequence or float or int): Range of degrees to select from.
+            If degrees is a number instead of sequence like (min, max), the range of degrees
+            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
+        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
+            and vertical translations. For example translate=(a, b), then horizontal shift
+            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
+            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
+        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
+            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
+        shear (sequence or float or int, optional): Range of degrees to select from.
+            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
+            will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
+            range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
+            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
+            Will not apply shear by default
+        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+            An optional resampling filter. See `filters`_ for more information.
+            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+        fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
+            outside the transform in the output image.(Pillow>=5.0.0)
+    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+    """
+    # degrees: rotation range for the image, in [-180, 180]
+    # translate: translation fraction for the image, in [0, 1]
+    # scale: scaling factor with the center kept invariant, ideally in (0, 2]
+    # shear: shear range along the x and/or y axis, in [-180, 180]
+    # e.g.
+    # preprocess1 = myTransforms.RandomAffine(degrees=0, translate=[0, 0.2], scale=[0.8, 1.2],
+    #                                        shear=[-10, 10, -10, 10], fillcolor=(228, 218, 218))
+    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            self.degrees = (-degrees, degrees)
+        else:
+            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
+                "degrees should be a list or tuple and it must be of length 2."
+            self.degrees = degrees
+
+        if translate is not None:
+            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                "translate should be a list or tuple and it must be of length 2."
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError("translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                "scale should be a list or tuple and it must be of length 2."
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            if isinstance(shear, numbers.Number):
+                if shear < 0:
+                    raise ValueError("If shear is a single number, it must be positive.")
+                self.shear = (-shear, shear)
+            else:
+                assert isinstance(shear, (tuple, list)) and \
+                    (len(shear) == 2 or len(shear) == 4), \
+                    "shear should be a list or tuple and it must be of length 2 or 4."
+                # X-Axis shear with [min, max]
+                if len(shear) == 2:
+                    self.shear = [shear[0], shear[1], 0., 0.]
+                elif len(shear) == 4:
+                    self.shear = [s for s in shear]
+        else:
+            self.shear = shear
+
+        self.resample = resample
+        self.fillcolor = fillcolor
+
+    @staticmethod
+    def get_params(degrees, translate, scale_ranges, shears, img_size):
+        """Get parameters for affine transformation
+        Returns:
+            sequence: params to be passed to the affine transformation
+        """
+        angle = random.uniform(degrees[0], degrees[1])
+        if translate is not None:
+            max_dx = translate[0] * img_size[0]
+            max_dy = translate[1] * img_size[1]
+            translations = (np.round(random.uniform(-max_dx, max_dx)),
+                            np.round(random.uniform(-max_dy, max_dy)))
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = random.uniform(scale_ranges[0], scale_ranges[1])
+        else:
+            scale = 1.0
+
+        if shears is not None:
+            if len(shears) == 2:
+                shear = [random.uniform(shears[0], shears[1]), 0.]
+            elif len(shears) == 4:
+                shear = [random.uniform(shears[0], shears[1]),
+                         random.uniform(shears[2], shears[3])]
+        else:
+            shear = 0.0
+
+        return angle, translations, scale, shear
+
+    def __call__(self, img):
+        """
+            img (PIL Image): Image to be transformed.
+        Returns:
+            PIL Image: Affine transformed image.
+        """
+        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
+        return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
+
+    def __repr__(self):
+        s = '{name}(degrees={degrees}'
+        if self.translate is not None:
+            s += ', translate={translate}'
+        if self.scale is not None:
+            s += ', scale={scale}'
+        if self.shear is not None:
+            s += ', shear={shear}'
+        if self.resample > 0:
+            s += ', resample={resample}'
+        if self.fillcolor != 0:
+            s += ', fillcolor={fillcolor}'
+        s += ')'
+        d = dict(self.__dict__)
+        d['resample'] = _pil_interpolation_to_str[d['resample']]
+        return s.format(name=self.__class__.__name__, **d)
+
+
+class Grayscale(object):
+    """Convert image to grayscale.
+    Args:
+        num_output_channels (int): (1 or 3) number of channels desired for output image
+    Returns:
+        PIL Image: Grayscale version of the input.
+        - If num_output_channels == 1 : returned image is single channel
+        - If num_output_channels == 3 : returned image is 3 channel with r == g == b
+    """
+
+    def __init__(self, num_output_channels=1):
+        self.num_output_channels = num_output_channels
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be converted to grayscale.
+        Returns:
+            PIL Image: Randomly grayscaled image.
+        """
+        return F.to_grayscale(img, num_output_channels=self.num_output_channels)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
+
+
+class RandomGrayscale(object):
+    """Randomly convert image to grayscale with a probability of p (default 0.1).
+    Args:
+        p (float): probability that image should be converted to grayscale.
+    Returns:
+        PIL Image: Grayscale version of the input image with probability p and unchanged
+        with probability (1-p).
+        - If input image is 1 channel: grayscale version is 1 channel
+        - If input image is 3 channel: grayscale version is 3 channel with r == g == b
+    """
+
+    def __init__(self, p=0.1):
+        self.p = p
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be converted to grayscale.
+        Returns:
+            PIL Image: Randomly grayscaled image.
+        """
+        num_output_channels = 1 if img.mode == 'L' else 3
+        if random.random() < self.p:
+            return F.to_grayscale(img, num_output_channels=num_output_channels)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={0})'.format(self.p)
+
+
+class RandomErasing(object):
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+    Args:
+         p: probability that the random erasing operation will be performed.
+         scale: range of proportion of erased area against input image.
+         ratio: range of aspect ratio of erased area.
+         value: erasing value. Default is 0. If a single int, it is used to
+            erase all pixels. If a tuple of length 3, it is used to erase
+            R, G, B channels respectively.
+            If a str of 'random', erasing each pixel with random values.
+         inplace: boolean to make this transform inplace. Default set to False.
+    Returns:
+        Erased Image.
+    # Examples:
+        >>> transform = transforms.Compose([
+        >>> transforms.RandomHorizontalFlip(),
+        >>> transforms.ToTensor(),
+        >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+        >>> transforms.RandomErasing(),
+        >>> ])
+    """
+
+    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
+        assert isinstance(value, (numbers.Number, str, tuple, list))
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            warnings.warn("range should be of kind (min, max)")
+        if scale[0] < 0 or scale[1] > 1:
+            raise ValueError("range of scale should be between 0 and 1")
+        if p < 0 or p > 1:
+            raise ValueError("range of random erasing probability should be between 0 and 1")
+
+        self.p = p
+        self.scale = scale
+        self.ratio = ratio
+        self.value = value
+        self.inplace = inplace
+
+    @staticmethod
+    def get_params(img, scale, ratio, value=0):
+        """Get parameters for ``erase`` for a random erasing.
+        Args:
+            img (Tensor): Tensor image of size (C, H, W) to be erased.
+            scale: range of proportion of erased area against input image.
+            ratio: range of aspect ratio of erased area.
+        Returns:
+            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
+        """
+        img_c, img_h, img_w = img.shape
+        area = img_h * img_w
+
+        for attempt in range(10):
+            erase_area = random.uniform(scale[0], scale[1]) * area
+            aspect_ratio = random.uniform(ratio[0], ratio[1])
+
+            h = int(round(math.sqrt(erase_area * aspect_ratio)))
+            w = int(round(math.sqrt(erase_area / aspect_ratio)))
+
+            if h < img_h and w < img_w:
+                i = random.randint(0, img_h - h)
+                j = random.randint(0, img_w - w)
+                if isinstance(value, numbers.Number):
+                    v = value
+                elif isinstance(value, str):  # e.g. value='random'
+                    v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
+                elif isinstance(value, (list, tuple)):
+                    v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
+                return i, j, h, w, v
+
+        # Return original image
+        return 0, 0, img_h, img_w, img
+
+    def __call__(self, img):
+        """
+        Args:
+            img (Tensor): Tensor image of size (C, H, W) to be erased.
+        Returns:
+            img (Tensor): Erased Tensor image.
+        """
+        if random.uniform(0, 1) < self.p:
+            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
+            return F.erase(img, x, y, h, w, v, self.inplace)
+        return img
+
+
+class HEDJitter(object):
+    """Randomly perturbe the HED color space value an RGB image.
+    First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix.
+    Second, it perturbed the hematoxylin, eosin and DAB stains independently.
+    Third, it transformed the resulting stains into regular RGB color space.
+    Args:
+        theta (float): How much to jitter HED color space,
+         alpha is chosen from a uniform distribution [1-theta, 1+theta]
+         betti is chosen from a uniform distribution [-theta, theta]
+         the jitter formula is **s' = \alpha * s + \betti**
+    """
+    def __init__(self, theta=0.): # HED_light: theta=0.05; HED_strong: theta=0.2
+        assert isinstance(theta, numbers.Number), "theta should be a single number."
+        self.theta = theta
+        self.alpha = np.random.uniform(1-theta, 1+theta, (1, 3))
+        self.betti = np.random.uniform(-theta, theta, (1, 3))
+
+    @staticmethod
+    def adjust_HED(img, alpha, betti):
+        img = np.array(img)
+        s = np.reshape(color.rgb2hed(img), (-1, 3))
+        ns = alpha * s + betti  # perturbations on HED color space
+        nimg = color.hed2rgb(np.reshape(ns, img.shape))
+
+        imin = nimg.min()
+        imax = nimg.max()
+        rsimg = (255 * (nimg - imin) / (imax - imin)).astype('uint8')  # rescale to [0,255]
+        # transfer to PIL image
+        return Image.fromarray(rsimg)
+
+    def __call__(self, img):
+        return self.adjust_HED(img, self.alpha, self.betti)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        format_string += 'theta={0}'.format(self.theta)
+        format_string += ', alpha={0}'.format(self.alpha)
+        format_string += ', betti={0})'.format(self.betti)
+        return format_string
+
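+# Illustrative usage sketch for HEDJitter on H&E-stained tiles (not part of the
+# original API). Because alpha and betti are drawn once in __init__, a single
+# HEDJitter instance applies the same perturbation to every image it is called on;
+# re-instantiate the transform to re-sample them.
+# >>> hed_light = HEDJitter(theta=0.05)   # "HED_light" setting from the comment above
+# >>> augmented = hed_light(pil_image)    # pil_image: any RGB PIL Image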
+
+class AutoRandomRotation(object):
+    """auto randomly select angle 0, 90, 180 or 270 for rotating the image.
+    Args:
+        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+            An optional resampling filter. See `filters`_ for more information.
+            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        center (2-tuple, optional): Optional center of rotation.
+            Origin is the upper left corner.
+            Default is the center of the image.
+        fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
+            If int, it is used for all channels respectively.
+    """
+
+    def __init__(self, degree=None, resample=False, expand=True, center=None, fill=0):
+        if degree is None:
+            self.degrees = random.choice([0, 90, 180, 270])
+        else:
+            assert degree in [0, 90, 180, 270], 'degree must be in [0, 90, 180, 270]'
+            self.degrees = degree
+
+        self.resample = resample
+        self.expand = expand
+        self.center = center
+        self.fill = fill
+
+    def __call__(self, img):
+        return F.rotate(img, self.degrees, self.resample, self.expand, self.center, self.fill)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
+        format_string += ', resample={0}'.format(self.resample)
+        format_string += ', expand={0}'.format(self.expand)
+        if self.center is not None:
+            format_string += ', center={0}'.format(self.center)
+        format_string += ')'
+        return format_string
+
+
+class RandomGaussBlur(object):
+    """Random GaussBlurring on image by radius parameter.
+    Args:
+        radius (list, tuple): radius range for selecting from; you'd better set it < 2
+    """
+    def __init__(self, radius=None):
+        if radius is not None:
+            assert isinstance(radius, (tuple, list)) and len(radius) == 2, \
+                "radius should be a list or tuple and it must be of length 2."
+            self.radius = random.uniform(radius[0], radius[1])
+        else:
+            self.radius = 0.0
+
+    def __call__(self, img):
+        return img.filter(ImageFilter.GaussianBlur(radius=self.radius))
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(Gaussian Blur radius={0})'.format(self.radius)
+
+
+class RandomAffineCV2(object):
+    """Random Affine transformation by CV2 method on image by alpha parameter.
+    Args:
+        alpha (float): alpha value for affine transformation
+        mask (PIL Image) in __call__, if not assign, set None.
+    """
+    def __init__(self, alpha):
+        assert isinstance(alpha, numbers.Number), "alpha should be a single number."
+        assert 0. <= alpha <= 0.15, \
+            "In pathological image, alpha should be in (0,0.15), you can change in myTransform.py"
+        self.alpha = alpha
+
+    @staticmethod
+    def affineTransformCV2(img, alpha, mask=None):
+        alpha = img.shape[1] * alpha
+        if mask is not None:
+            mask = np.array(mask).astype(np.uint8)
+            img = np.concatenate((img, mask[..., None]), axis=2)
+
+        imgsize = img.shape[:2]
+        center = np.float32(imgsize) // 2
+        censize = min(imgsize) // 3
+        pts1 = np.float32([center+censize, [center[0]+censize, center[1]-censize], center-censize])  # raw point
+        pts2 = pts1 + np.random.uniform(-alpha, alpha, size=pts1.shape).astype(np.float32)  # output point
+        M = cv2.getAffineTransform(pts1, pts2)  # affine matrix
+        img = cv2.warpAffine(img, M, imgsize[::-1],
+                               flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT_101)
+        if mask is not None:
+            return Image.fromarray(img[..., :3]), Image.fromarray(img[..., 3])
+        else:
+            return Image.fromarray(img)
+
+    def __call__(self, img, mask=None):
+        return self.affineTransformCV2(np.array(img), self.alpha, mask)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(alpha value={0})'.format(self.alpha)
+
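+# Illustrative sketch (assumed PIL inputs `pil_image`, `pil_mask`): RandomAffineCV2
+# (above) and RandomElastic (below) both accept an optional mask in __call__, so an
+# image and its segmentation mask can be deformed with the same random parameters.
+# >>> img_aug, mask_aug = RandomAffineCV2(alpha=0.1)(pil_image, mask=pil_mask)
+# >>> img_aug, mask_aug = RandomElastic(alpha=2, sigma=0.06)(pil_image, mask=pil_mask)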
+
+class RandomElastic(object):
+    """Random Elastic transformation by CV2 method on image by alpha, sigma parameter.
+        # you can refer to:  https://blog.csdn.net/qq_27261889/article/details/80720359
+        # https://blog.csdn.net/maliang_1993/article/details/82020596
+        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html#scipy.ndimage.map_coordinates
+    Args:
+        alpha (float): scaling factor for the elastic displacement;
+            if alpha is 0, the output equals the input regardless of sigma;
+            if alpha is 1, the output depends only on sigma;
+            otherwise alpha scales the displacement fields dx, dy up or down.
+        sigma (float): smoothing factor for the displacement field, should be in (0.05, 0.1)
+        mask (PIL Image, optional): mask passed to __call__; defaults to None.
+    """
+    def __init__(self, alpha, sigma):
+        assert isinstance(alpha, numbers.Number) and isinstance(sigma, numbers.Number), \
+            "alpha and sigma should be a single number."
+        assert 0.05 <= sigma <= 0.1, \
+            "In pathological image, sigma should be in (0.05,0.1)"
+        self.alpha = alpha
+        self.sigma = sigma
+
+    @staticmethod
+    def RandomElasticCV2(img, alpha, sigma, mask=None):
+        alpha = img.shape[1] * alpha
+        sigma = img.shape[1] * sigma
+        if mask is not None:
+            mask = np.array(mask).astype(np.uint8)
+            img = np.concatenate((img, mask[..., None]), axis=2)
+
+        shape = img.shape
+
+        dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
+        dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
+        # dz = np.zeros_like(dx)
+
+        x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
+        indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))
+
+        img = map_coordinates(img, indices, order=0, mode='reflect').reshape(shape)
+        if mask is not None:
+            return Image.fromarray(img[..., :3]), Image.fromarray(img[..., 3])
+        else:
+            return Image.fromarray(img)
+
+    def __call__(self, img, mask=None):
+        return self.RandomElasticCV2(np.array(img), self.alpha, self.sigma, mask)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '(alpha={0}'.format(self.alpha)
+        format_string += ', sigma={0}'.format(self.sigma)
+        format_string += ')'
+        return format_string
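+
+# A hedged example of composing the pathology-specific augmentations defined above
+# with the standard geometric ones for histology tiles; the parameter values are
+# illustrative only, and Compose is assumed to be defined earlier in this module.
+# >>> train_aug = Compose([
+# >>>     AutoRandomRotation(),               # 0/90/180/270 degree rotation
+# >>>     RandomHorizontalFlip(p=0.5),
+# >>>     HEDJitter(theta=0.05),              # light stain jitter
+# >>>     RandomGaussBlur(radius=[0.1, 1.0]),
+# >>> ])
+# >>> tile = train_aug(pil_tile)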
\ No newline at end of file
diff --git a/code/models/TransMIL.py b/code/models/TransMIL.py
index 01a0fa7a9ec3ebfa0c8655eb6f68bafa79fc6f01..d75eb5f3ea2633ac67a4fcbe6eab35341cbd88e6 100755
--- a/code/models/TransMIL.py
+++ b/code/models/TransMIL.py
@@ -58,19 +58,45 @@ class PPEG(nn.Module):
 
 
 class TransMIL(nn.Module):
-    def __init__(self, n_classes):
+    def __init__(self, n_classes, in_features, out_features=512):
         super(TransMIL, self).__init__()
-        in_features = 2048
-        inter_features = 1024
-        out_features = 512
+        # in_features = 2048
+        # inter_features = 1024
+        # inter_features_2 = 512
+        # out_features = 1024 
+        # out_features = 512 
         if apex_available: 
             norm_layer = apex.normalization.FusedLayerNorm
         else:
             norm_layer = nn.LayerNorm
 
         self.pos_layer = PPEG(dim=out_features)
-        self._fc1 = nn.Sequential(nn.Linear(in_features, inter_features), nn.GELU(), nn.Dropout(p=0.5), norm_layer(inter_features)) 
-        self._fc1_2 = nn.Sequential(nn.Linear(inter_features, out_features), nn.GELU())
+        # self._fc1 = nn.Sequential(nn.Linear(in_features, int(in_features/2)), nn.GELU(), nn.Dropout(p=0.2), norm_layer(int(in_features/2))) # 2048 -> 1024
+        # self._fc1_1 = nn.Sequential(nn.Linear(int(in_features/2), int(in_features/2)), nn.GELU(), nn.Dropout(p=0.2), norm_layer(int(in_features/2))) # 2048 -> 1024
+        # self._fc1_2 = nn.Sequential(nn.Linear(int(in_features/2), int(in_features/2)), nn.GELU(), nn.Dropout(p=0.2), norm_layer(int(in_features/2))) # 2048 -> 1024
+        # self._fc2 = nn.Sequential(nn.Linear(int(in_features/2), int(in_features/4)), nn.GELU(), nn.Dropout(p=0.2), norm_layer(int(in_features/4))) # 1024 -> 512
+        # self._fc3 = nn.Sequential(nn.Linear(int(in_features/4), out_features), nn.GELU()) # 512 -> 256
+
+
+
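+        # The projection head below adapts to the backbone feature size: 2048-d bag
+        # features (e.g. ResNet50-style backbones) get a two-step projection with GELU,
+        # Dropout(0.6) and layer norm; 1024-d features get a single projection layer.
+        # Other values of in_features are not handled and would leave self._fc1 undefined.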
+        if in_features == 2048:
+            self._fc1 = nn.Sequential(
+                nn.Linear(in_features, int(in_features/2)), nn.GELU(), nn.Dropout(p=0.6), norm_layer(int(in_features/2)),
+                nn.Linear(int(in_features/2), out_features), nn.GELU(),
+                ) 
+        elif in_features == 1024:
+            self._fc1 = nn.Sequential(
+                # nn.Linear(in_features, int(in_features/2)), nn.GELU(), nn.Dropout(p=0.2), norm_layer(out_features),
+                nn.Linear(in_features, out_features), nn.GELU(), nn.Dropout(p=0.6), norm_layer(out_features)
+                ) 
+        # out_features = 256 
+        # self._fc1 = nn.Sequential(
+        #     nn.Linear(in_features, out_features), nn.GELU(), nn.Dropout(p=0.2), norm_layer(out_features)
+        #     ) 
+        # self._fc1_2 = nn.Sequential(nn.Linear(inter_features, inter_features_2), nn.GELU(), nn.Dropout(p=0.5), norm_layer(inter_features_2)) 
+        # self._fc1_3 = nn.Sequential(nn.Linear(inter_features_2, out_features), nn.GELU())
+        # self._fc1 = nn.Sequential(nn.Linear(in_features, 256), nn.GELU())
+        # self._fc1_2 = nn.Sequential(nn.Linear(int(in_features/2), out_features), nn.GELU())
         # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
         
         self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
@@ -79,7 +105,7 @@ class TransMIL(nn.Module):
         self.layer2 = TransLayer(norm_layer=norm_layer, dim=out_features)
         # self.norm = nn.LayerNorm(out_features)
         self.norm = norm_layer(out_features)
-        self._fc2 = nn.Linear(out_features, self.n_classes)
+        self._fc = nn.Linear(out_features, self.n_classes)
 
         # self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True).to(self.device)
         # home = Path.cwd().parts[1]
@@ -94,11 +120,15 @@ class TransMIL(nn.Module):
     def forward(self, x): #, **kwargs
 
         # x = self.model_ft(x).unsqueeze(0)
-        h = x.squeeze(0).float() #[B, n, 1024]
+        if x.dim() > 3:
+            x = x.squeeze(0)
+        h = x.float() #[B, n, 1024]
         h = self._fc1(h) #[B, n, 512]
         # h = self.drop(h)
-        h = self._fc1_2(h) #[B, n, 512]
-        
+        # h = self._fc1_1(h) #[B, n, 512]
+        # h = self._fc1_2(h) #[B, n, 512]
+        # h = self._fc2(h) #[B, n, 512]
+        # h = self._fc3(h) #[B, n, 512]
         # print('Feature Representation: ', h.shape)
         #---->duplicate pad
         H = h.shape[1]
@@ -135,7 +165,7 @@ class TransMIL(nn.Module):
         h = self.norm(h)[:,0]
 
         #---->predict
-        logits = self._fc2(h) #[B, n_classes]
+        logits = self._fc(h) #[B, n_classes]
         # Y_hat = torch.argmax(logits, dim=1)
         # Y_prob = F.softmax(logits, dim = 1)
         # results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
diff --git a/code/models/TransformerMIL.py b/code/models/TransformerMIL.py
index d5e6b89fa4088afc1b143362b627f33de7d57ce6..c5b595134337c9a2ea168bfc18c82b7e9ee122fc 100644
--- a/code/models/TransformerMIL.py
+++ b/code/models/TransformerMIL.py
@@ -4,6 +4,13 @@ import torch.nn.functional as F
 import numpy as np
 from nystrom_attention import NystromAttention
 
+try:
+    import apex
+    apex_available=True
+except ModuleNotFoundError:
+    # Error handling
+    apex_available = False
+    pass
 
 class TransLayer(nn.Module):
 
@@ -46,12 +53,25 @@ class PPEG(nn.Module):
 
 
 class TransformerMIL(nn.Module):
-    def __init__(self, n_classes):
+    def __init__(self, n_classes, in_features, out_features=512):
         super(TransformerMIL, self).__init__()
-        in_features = 1024
-        out_features = 512
+        # in_features = 2048
+        # out_features = 512
         # self.pos_layer = PPEG(dim=out_features)
-        self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
+        if apex_available: 
+            norm_layer = apex.normalization.FusedLayerNorm
+        else:
+            norm_layer = nn.LayerNorm
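+        # Dimension-adaptive projection head, mirroring TransMIL.py; in_features
+        # values other than 2048 or 1024 are not handled here.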
+        if in_features == 2048:
+            self._fc1 = nn.Sequential(
+                nn.Linear(in_features, int(in_features/2)), nn.GELU(), nn.Dropout(p=0.6), norm_layer(int(in_features/2)),
+                nn.Linear(int(in_features/2), out_features), nn.GELU(),
+                ) 
+        elif in_features == 1024:
+            self._fc1 = nn.Sequential(
+                # nn.Linear(in_features, int(in_features/2)), nn.GELU(), nn.Dropout(p=0.2), norm_layer(out_features),
+                nn.Linear(in_features, out_features), nn.GELU(), nn.Dropout(p=0.6), norm_layer(out_features)
+                ) 
         # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
         self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
         self.n_classes = n_classes
@@ -63,7 +83,7 @@ class TransformerMIL(nn.Module):
 
     def forward(self, x): #, **kwargs
 
-        h = x.float() #[B, n, 1024]
+        h = x.squeeze(0).float() #[B, n, 1024]
         h = self._fc1(h) #[B, n, 512]
         
         # print('Feature Representation: ', h.shape)
@@ -102,7 +122,7 @@ class TransformerMIL(nn.Module):
         #---->predict
         logits = self._fc2(h) #[B, n_classes]
         # return logits, attn2
-        return logits, attn1
+        return logits
 
 if __name__ == "__main__":
     data = torch.randn((1, 6000, 512)).cuda()
diff --git a/code/models/__pycache__/TransMIL.cpython-39.pyc b/code/models/__pycache__/TransMIL.cpython-39.pyc
index 21d7707719795a1335b1187bffed7fd0328ba3fb..0e5f4fc107123caf151ac1282261571966a5a167 100644
Binary files a/code/models/__pycache__/TransMIL.cpython-39.pyc and b/code/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/TransformerMIL.cpython-39.pyc b/code/models/__pycache__/TransformerMIL.cpython-39.pyc
index f2c5bda0715cd2d8601f4a33e2187a553a3f322a..8b1abd15afbfda834b941d0a5e2eff2d1aa7773d 100644
Binary files a/code/models/__pycache__/TransformerMIL.cpython-39.pyc and b/code/models/__pycache__/TransformerMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/model_interface.cpython-39.pyc b/code/models/__pycache__/model_interface.cpython-39.pyc
index 0466ab737a0a0da0cfa6b04ea69f2aef82561a6d..3da7b2035d7f1297664e6c70bc48be9b1643e010 100644
Binary files a/code/models/__pycache__/model_interface.cpython-39.pyc and b/code/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/code/models/__pycache__/model_interface_classic.cpython-39.pyc b/code/models/__pycache__/model_interface_classic.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97044fe05269d43ef2c97e75a0d3b37f58673a23
Binary files /dev/null and b/code/models/__pycache__/model_interface_classic.cpython-39.pyc differ
diff --git a/code/models/model_interface.py b/code/models/model_interface.py
index 0186c069c9f62c945c40647906823ed424deb06f..3d37368d9aa6763ec01ae77fbcdf304ebb6c89da 100755
--- a/code/models/model_interface.py
+++ b/code/models/model_interface.py
@@ -8,6 +8,8 @@ import pandas as pd
 import seaborn as sns
 from pathlib import Path
 from matplotlib import pyplot as plt
+plt.style.use('tableau-colorblind10')
+import pandas as pd
 import cv2
 from PIL import Image
 from pytorch_pretrained_vit import ViT
@@ -28,8 +30,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics
 from torchmetrics.functional import stat_scores
+from torchmetrics.functional.classification import binary_auroc, multiclass_auroc, binary_precision_recall_curve, multiclass_precision_recall_curve
 from torch import optim as optim
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
 
 from monai.config import KeysCollection
 from monai.data import Dataset, load_decathlon_datalist
@@ -37,6 +40,7 @@ from monai.data.wsi_reader import WSIReader
 from monai.metrics import Cumulative, CumulativeAverage
 from monai.networks.nets import milmodel
 
+
 # from sklearn.metrics import roc_curve, auc, roc_curve_score
 
 
@@ -81,24 +85,27 @@ class ModelInterface(pl.LightningModule):
         if model.name == 'AttTrans':
             self.model = milmodel.MILModel(num_classes=self.n_classes, pretrained=True, mil_mode='att_trans', backbone_num_features=1024)
         else: self.load_model()
-        # self.loss = create_loss(loss, model.n_classes)
-        # self.loss = 
         if self.n_classes>2:
-            self.aucm_loss = AUCM_MultiLabel(num_classes = model.n_classes, device=self.device)
+            # self.aucm_loss = AUCM_MultiLabel(num_classes = self.n_classes, device=self.device)
+            # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
+            self.loss = create_loss(loss, model.n_classes)
         else:
-            self.aucm_loss = AUCMLoss()
+            # self.loss = CompositionalAUCLoss()
+            self.loss = create_loss(loss, model.n_classes)
         # self.asl = AsymmetricLossSingleLabel()
-        self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
+        self.lsce_loss = LabelSmoothingCrossEntropy(smoothing=0.2)
 
-        # self.loss = 
-        # print(self.model)
         self.model_name = model.name
         
         
-        # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
         self.optimizer = optimizer
         
         self.save_path = kargs['log']
+        
+        # self.in_features = kargs['in_features']
+        # self.out_features = kargs['out_features']
+        self.in_features = 2048
+        self.out_features = 512
         if Path(self.save_path).parts[3] == 'tcmr':
             temp = list(Path(self.save_path).parts)
             # print(temp)
@@ -112,40 +119,38 @@ class ModelInterface(pl.LightningModule):
 
         #---->acc
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        self.data_patient = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
         # print(self.experiment)
         #---->Metrics
         if self.n_classes > 2: 
-            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average='macro')
-            
-            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
+            self.AUROC = torchmetrics.AUROC(task='multiclass', num_classes = self.n_classes, average='weighted')
+            self.PRC = torchmetrics.PrecisionRecallCurve(task='multiclass', num_classes = self.n_classes)
+            self.ROC = torchmetrics.ROC(task='multiclass', num_classes=self.n_classes)
+            self.confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes = self.n_classes) 
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(task='multiclass', num_classes = self.n_classes,
                                                                            average='weighted'),
-                                                     torchmetrics.CohenKappa(num_classes = self.n_classes),
-                                                     torchmetrics.F1Score(num_classes = self.n_classes,
+                                                     torchmetrics.CohenKappa(task='multiclass', num_classes = self.n_classes),
+                                                     torchmetrics.F1Score(task='multiclass', num_classes = self.n_classes,
                                                                      average = 'macro'),
-                                                     torchmetrics.Recall(average = 'macro',
+                                                     torchmetrics.Recall(task='multiclass', average = 'macro',
                                                                          num_classes = self.n_classes),
-                                                     torchmetrics.Precision(average = 'macro',
+                                                     torchmetrics.Precision(task='multiclass', average = 'macro',
                                                                             num_classes = self.n_classes),
-                                                     torchmetrics.Specificity(average = 'macro',
+                                                     torchmetrics.Specificity(task='multiclass', average = 'macro',
                                                                             num_classes = self.n_classes)])
                                                                             
         else : 
-            self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average='weighted')
+            self.AUROC = torchmetrics.AUROC(task='binary')
             # self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average = 'weighted')
-
-            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
-                                                                           average = 'weighted'),
-                                                     torchmetrics.CohenKappa(num_classes = 2),
-                                                     torchmetrics.F1Score(num_classes = 2,
-                                                                     average = 'macro'),
-                                                     torchmetrics.Recall(average = 'macro',
-                                                                         num_classes = 2),
-                                                     torchmetrics.Precision(average = 'macro',
-                                                                            num_classes = 2)])
-        self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes)
-        self.ROC = torchmetrics.ROC(num_classes=self.n_classes)
-        # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10)
-        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)                                                                    
+            self.PRC = torchmetrics.PrecisionRecallCurve(task='binary')
+            self.ROC = torchmetrics.ROC(task='binary')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(task='binary'),
+                                                     torchmetrics.CohenKappa(task='binary'),
+                                                     torchmetrics.F1Score(task='binary'),
+                                                     torchmetrics.Recall(task='binary'),
+                                                     torchmetrics.Precision(task='binary')
+                                                     ])
+            self.confusion_matrix = torchmetrics.ConfusionMatrix(task='binary')                                                                    
         self.valid_metrics = metrics.clone(prefix = 'val_')
         self.valid_patient_metrics = metrics.clone(prefix = 'val_patient_')
         self.test_metrics = metrics.clone(prefix = 'test_')
@@ -156,13 +161,23 @@ class ModelInterface(pl.LightningModule):
         self.count = 0
         self.backbone = kargs['backbone']
 
-        self.out_features = 1024
 
         if self.backbone == 'features':
             self.model_ft = None
+            
         elif self.backbone == 'dino':
             self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
             self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
+        # elif self.backbone == 'inception':
+        #     self.model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
+        #     self.model_ft.aux_logits = False
+        #     for parameter in self.model_ft.parameters():
+        #         parameter.requires_grad = False
+
+        #     self.model_ft.fc = nn.Sequential(nn.Linear(model.fc.in_features, 10),
+        #                                     nn.Linear(10, self)
+        #     )
+
         elif self.backbone == 'resnet18':
             self.model_ft = models.resnet18(weights='IMAGENET1K_V1')
             # modules = list(resnet18.children())[:-1]
@@ -188,7 +203,7 @@ class ModelInterface(pl.LightningModule):
             # )
         elif self.backbone == 'retccl':
             # import models.ResNet as ResNet
-            self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True)
+            self.model_ft = ResNet.resnet50(num_classes=128, mlp=False, two_branch=False, normlinear=True)
             home = Path.cwd().parts[1]
             # pre_model = 
             # self.model_ft.fc = nn.Identity()
@@ -196,8 +211,8 @@ class ModelInterface(pl.LightningModule):
             self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
             for param in self.model_ft.parameters():
                 param.requires_grad = False
-            self.model_ft.fc = nn.Linear(2048, self.out_features)
-            
+            self.model_ft.fc = torch.nn.Identity()
+            # self.model_ft.eval()
             # self.model_ft = FeatureExtractor('retccl', self.n_classes)
 
 
@@ -207,25 +222,6 @@ class ModelInterface(pl.LightningModule):
             for param in self.model_ft.parameters():
                 param.requires_grad = False
 
-            # self.model_ft = models.resnet50(pretrained=True)
-            # for param in self.model_ft.parameters():
-            #     param.requires_grad = False
-            # self.model_ft.fc = nn.Linear(2048, self.out_features)
-
-
-            # modules = list(resnet50.children())[:-3]
-            # res50 = nn.Sequential(
-            #     *modules,     
-            # )
-            
-            # self.model_ft = nn.Sequential(
-            #     res50,
-            #     nn.AdaptiveAvgPool2d(1),
-            #     View((-1, 1024)),
-            #     nn.Linear(1024, self.out_features),
-            #     # nn.GELU()
-            # )
-        # elif kargs
             
         elif self.backbone == 'efficientnet':
             efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
@@ -250,6 +246,10 @@ class ModelInterface(pl.LightningModule):
                 nn.Linear(53*53, self.out_features),
                 nn.ReLU(),
             )
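+        # example_input_array lets Lightning build the model summary / trace the graph: a dummy bag of 224x224 RGB tiles when a CNN backbone is attached, otherwise a bag of 1000 precomputed feature vectors.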
+        if self.model_ft:
+            self.example_input_array = torch.rand([1,1,3,224,224])
+        else:
+            self.example_input_array = torch.rand([1,1000,self.in_features])
         # print(self.model_ft[0].features[-1])
         # print(self.model_ft)
 
@@ -260,17 +260,26 @@ class ModelInterface(pl.LightningModule):
         if self.model_name == 'AttTrans':
             return self.model(x)
         if self.model_ft:
-            x = x.squeeze(0)
+            # x = x.squeeze(0)
+            # if x.dim() == 5:
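+            # bags arrive as [batch, bag_size, C, H, W]; flatten them so the backbone sees a plain image batch, then restore the bag dimension on the extracted features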
+            batch_size = x.shape[0]
+            bag_size = x.shape[1]
+            x = x.view(batch_size*bag_size, x.shape[2], x.shape[3], x.shape[4])
             feats = self.model_ft(x).unsqueeze(0)
+            # print(feats.shape)
+            # print(x.shape)
+            # if feats.dim() == 3:
+            feats = feats.view(batch_size, bag_size, -1)
         else: 
             feats = x.unsqueeze(0)
-        
+        del x
         return self.model(feats)
         # return self.model(x)
 
     def step(self, input):
 
         input = input.float()
+        # print(input.shape)
         # logits, _ = self(input.contiguous()) 
         logits = self(input.contiguous())
         Y_hat = torch.argmax(logits, dim=1)
@@ -282,45 +291,36 @@ class ModelInterface(pl.LightningModule):
 
         return logits, Y_prob, Y_hat
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch):
 
         input, label, _= batch
 
-        #random image dropout
 
-        # bag_size = input.squeeze().shape[0] * 0.7
-        # bag_idxs = torch.randperm(input.squeeze(0).shape[0])[:bag_size]
-        # input = input.squeeze(0)[bag_idxs].unsqueeze(0)
-
-        # label = label.float()
-        
         logits, Y_prob, Y_hat = self.step(input) 
 
         #---->loss
-        loss = self.loss(logits, label)
+        # loss = self.loss(logits, label)
 
         one_hot_label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
-        # aucm_loss = self.aucm_loss(torch.sigmoid(logits), one_hot_label)
-        # total_loss = torch.mean(loss + aucm_loss)
-        Y = int(label)
-        # print(logits, label)
-        # loss = cross_entropy_torch(logits.squeeze(0), label)
-        # loss = self.asl(logits, label.squeeze())
-
-        #---->acc log
-        # print(label)
-        # Y_hat = int(Y_hat)
-        # if self.n_classes == 2:
-        #     Y = int(label[0][1])
-        # else: 
-        # Y = torch.argmax(label)
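+        # labels are one-hot encoded and cast to float: CrossEntropyLoss in recent PyTorch accepts probability targets, and the (commented-out) AUC-margin losses expect one-hot targets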
+        loss = self.loss(logits, one_hot_label.float())
+        if loss.ndim == 0:
+            loss = loss.unsqueeze(0)
+        # if self.n_classes > 2: 
+        #     aucm_loss = loss
         
-            # Y = int(label[0])
-        self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (int(Y_hat) == Y)
+
+        # total_loss = (aucm_loss + loss)/2
+        for y, y_hat in zip(label, Y_hat):
+            
+            y = int(y)
+            # print(Y_hat)
+            self.data[y]["count"] += 1
+            self.data[y]["correct"] += (int(y_hat) == y)
+
+
         # self.log('total_loss', total_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
-        # self.log('aucm_loss', aucm_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
-        self.log('lsce_loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        # self.log('lsce_loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
 
         # if self.current_epoch % 10 == 0:
 
@@ -339,11 +339,23 @@ class ModelInterface(pl.LightningModule):
         return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
 
     def training_epoch_end(self, training_step_outputs):
-        # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
+
+        # for t in training_step_outputs:
+        # probs = torch.cat([torch.cat(x[0]['Y_prob'], x[1]['Y_prob']) for x in training_step_outputs])
+        # max_probs = torch.stack([torch.stack(x[0]['Y_hat'], x[1]['Y_hat']) for x in training_step_outputs])
+        # target = torch.stack([torch.stack(x[0]['label'], x[1]['label']) for x in training_step_outputs])
+            # print(t)
+
         probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
-        max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs])
-        # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
-        target = torch.stack([x['label'] for x in training_step_outputs])
+        max_probs = torch.cat([x['Y_hat'] for x in training_step_outputs])
+        # print(max_probs)
+        target = torch.cat([x['label'] for x in training_step_outputs], dim=0).int()
+
+        # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
+        # probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
+        # max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs])
+        # # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
+        # target = torch.stack([x['label'] for x in training_step_outputs])
         # target = torch.argmax(target, dim=1)
 
         if self.current_epoch % 5 == 0:
@@ -361,31 +373,49 @@ class ModelInterface(pl.LightningModule):
         # print('probs: ', probs)
         if self.current_epoch % 10 == 0:
             self.log_confusion_matrix(max_probs, target, stage='train')
+        if self.n_classes <=2:
+            out_probs = probs[:,1] 
+        else: out_probs = probs
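+        # torchmetrics' binary AUROC expects only the positive-class probability, hence the column selection for 2-class problems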
 
-        self.log('Train/auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        self.log('Train/auc', self.AUROC(out_probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
     def validation_step(self, batch, batch_idx):
 
-        input, label, (wsi_name, batch_names, patient) = batch
+        input, label, (wsi_name, patient) = batch
         # label = label.float()
         
         logits, Y_prob, Y_hat = self.step(input) 
+        logits = logits.detach()
+        Y_prob = Y_prob.detach()
+        Y_hat = Y_hat.detach()
 
         #---->acc log
         # Y = int(label[0][1])
         # Y = torch.argmax(label)
-        loss = self.loss(logits, label)
+        loss = self.lsce_loss(logits, label)
+        # loss = cross_entropy_torch(logits, label)
+        # one_hot_label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        # print(logits)
+        # print(label)
+        # print(one_hot_label)
+        # aucm_loss = self.aucm_loss(logits, one_hot_label.float())
+        # if aucm_loss.ndim == 0:
+        #     aucm_loss = aucm_loss.unsqueeze(0)
+        # print(aucm_loss)
         # loss = self.loss(logits, label)
+        # total_loss = (aucm_loss + loss)/2
         # print(loss)
-        Y = int(label)
 
-        # print(Y_hat)
-        self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (int(Y_hat) == Y)
+        for y, y_hat in zip(label, Y_hat):
+            y = int(y)
+            # print(Y_hat)
+            self.data[y]["count"] += 1
+            self.data[y]["correct"] += (int(y_hat) == y)
         
         # self.data[Y]["correct"] += (Y_hat.item() == Y)
-
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': wsi_name, 'patient': patient, 'loss':loss}
+        self.log('val_loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        # self.log('val_aucm_loss', aucm_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label.int(), 'name': wsi_name, 'patient': patient, 'loss':loss}
 
 
     def validation_epoch_end(self, val_step_outputs):
@@ -396,18 +426,23 @@ class ModelInterface(pl.LightningModule):
         
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
-        max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        target = torch.stack([x['label'] for x in val_step_outputs], dim=0).int()
+        max_probs = torch.cat([x['Y_hat'] for x in val_step_outputs])
+        # print(max_probs)
+        target = torch.cat([x['label'] for x in val_step_outputs])
         slide_names = [x['name'] for x in val_step_outputs]
         patients = [x['patient'] for x in val_step_outputs]
 
         loss = torch.stack([x['loss'] for x in val_step_outputs])
+        
+        # print(loss)
+        # print(loss.mean())
+        # print(loss.shape)
         # loss = torch.cat([x['loss'] for x in val_step_outputs])
         # print(loss.shape)
         
 
         # self.log('val_loss', cross_entropy_torch(logits.squeeze(), target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
-        self.log('val_loss', loss, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        # self.log('val_loss', loss.mean(), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         
         # print(logits)
         # print(target)
@@ -415,8 +450,16 @@ class ModelInterface(pl.LightningModule):
                           on_epoch = True, logger = True, sync_dist=True)
         
 
+        if self.n_classes <=2:
+            out_probs = probs[:,1] 
+        else: out_probs = probs
+
+        if self.n_classes <= 2:
+            bin_auroc = binary_auroc(out_probs, target.squeeze())
+            # print('val_bin_auroc: ', bin_auroc)
+
+        # print(target.unique())
         if len(target.unique()) != 1:
-            self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            self.log('val_auc', self.AUROC(out_probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
             # self.log('val_patient_auc', self.AUROC(patient_score, patient_target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         else:    
             self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
@@ -435,55 +478,77 @@ class ModelInterface(pl.LightningModule):
         patient_list = []            
         patient_score = []      
         patient_target = []
+        patient_class_score = 0
 
         for p, s, pr, t in zip(patients, slide_names, probs, target):
+            p = p[0]
+            # print(s[0])
+            # print(pr)
             if p not in complete_patient_dict.keys():
-                complete_patient_dict[p] = [(s, pr)]
+                complete_patient_dict[p] = {'scores':[(s[0], pr)], 'patient_score': 0}
+                # print((s,pr))
+                # complete_patient_dict[p]['scores'] = []
+                # print(t)
                 patient_target.append(t)
             else:
-                complete_patient_dict[p].append((s, pr))
+                complete_patient_dict[p]['scores'].append((s[0], pr))
 
-       
+        # print(complete_patient_dict)
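+        # pool slide-level probabilities into one score per patient: for binary tasks, if any slide is predicted positive only those slides are averaged, otherwise all slides contribute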
 
         for p in complete_patient_dict.keys():
+            # complete_patient_dict[p] = 0
             score = []
-            for (slide, probs) in complete_patient_dict[p]:
-                # max_probs = torch.argmax(probs)
-                # if self.n_classes == 2:
-                #     score.append(max_probs)
-                # else: score.append(probs)
+            for (slide, probs) in complete_patient_dict[p]['scores']:
                 score.append(probs)
-
-            # if self.n_classes == 2:
-                # score =
-            score = torch.mean(torch.stack(score), dim=0) #.cpu().detach().numpy()
-            # complete_patient_dict[p]['score'] = score
-            # print(p, score)
-            # patient_list.append(p)    
-            patient_score.append(score)    
+            # print(score)
+            score = torch.stack(score)
+            # print(score)
+            if self.n_classes == 2:
+                positive_positions = (score.argmax(dim=1) == 1).nonzero().squeeze()
+                # print(positive_positions)
+                if positive_positions.numel() != 0:
+                    score = score[positive_positions]
+            if len(score.shape) > 1:
+                score = torch.mean(score, dim=0) #.cpu().detach().numpy()
+
+            patient_score.append(score)  
+            complete_patient_dict[p]['patient_score'] = score
+        correct_patients = []
+        false_patients = []
+
+        for patient, label in zip(complete_patient_dict.keys(), patient_target):
+            if label == 0:
+                p_score =  complete_patient_dict[patient]['patient_score']
+                # print(torch.argmax(patient_score))
+                if torch.argmax(p_score) == label:
+                    correct_patients.append(patient)
+                else: 
+                    false_patients.append(patient)
 
         patient_score = torch.stack(patient_score)
-        # print(patient_target)
-        # print(torch.cat(patient_target))
-        # print(self.AUROC(patient_score.squeeze(), torch.cat(patient_target)))
-
         
-        patient_target = torch.cat(patient_target)
-
+        if self.n_classes <=2:
+            patient_score = patient_score[:,1] 
+        patient_target = torch.stack(patient_target)
+        # print(patient_target)
+        # patient_target = torch.cat(patient_target)
+        # self.log_confusion_matrix(max_probs, target, stage='test', comment='patient')
         # print(patient_score.shape)
         # print(patient_target.shape)
-        
+        if len(patient_target.shape) >1:
+            patient_target = patient_target.squeeze()
+        self.log_roc_curve(patient_score, patient_target, stage='val')
+        # self.log_roc_curve(patient_score, patient_target.squeeze(), stage='test')
+
+        # if self.current_epoch < 20:
+        #     self.log('val_patient_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         if len(patient_target.unique()) != 1:
-            self.log('val_patient_auc', self.AUROC(patient_score.squeeze(), patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            self.log('val_patient_auc', self.AUROC(patient_score, patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         else:    
             self.log('val_patient_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         
-        self.log_dict(self.valid_patient_metrics(patient_score, patient_target),
+        self.log_dict(self.valid_patient_metrics(patient_score, patient_target),
                           on_epoch = True, logger = True, sync_dist=True)
-        
-            
-
-        # precision, recall, thresholds = self.PRC(probs, target)
 
         
 
@@ -507,36 +572,51 @@ class ModelInterface(pl.LightningModule):
 
     def test_step(self, batch, batch_idx):
 
-        input, label, (wsi_name, batch_names, patient) = batch
-        label = label.float()
-        
+        input, label, (wsi_name, patient) = batch
+        # input, label, (wsi_name, batch_names, patient) = batch
+        # label = label.float()
+        # 
         logits, Y_prob, Y_hat = self.step(input) 
-
+        loss = self.lsce_loss(logits, label)
         #---->acc log
-        Y = int(label)
+        # Y = int(label)
+        for y, y_hat in zip(label, Y_hat):
+            
+            y = int(y)
+            # print(Y_hat)
+            self.data[y]["count"] += 1
+            self.data[y]["correct"] += (int(y_hat) == y)
+
         # Y = torch.argmax(label)
 
-        # print(Y_hat)
-        self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (int(Y_hat) == Y)
+        # # print(Y_hat)
+        # self.data[Y]["count"] += 1
+        # self.data[Y]["correct"] += (int(Y_hat) == Y)
         # self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': wsi_name, 'patient': patient}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label.int(), 'loss': loss, 'name': wsi_name, 'patient': patient}
 
     def test_epoch_end(self, output_results):
         logits = torch.cat([x['logits'] for x in output_results], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in output_results])
-        max_probs = torch.stack([x['Y_hat'] for x in output_results])
-        target = torch.stack([x['label'] for x in output_results]).int()
+        # max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        max_probs = torch.cat([x['Y_hat'] for x in output_results])
+        target = torch.cat([x['label'] for x in output_results])
         slide_names = [x['name'] for x in output_results]
         patients = [x['patient'] for x in output_results]
+        loss = torch.stack([x['loss'] for x in output_results])
         
         self.log_dict(self.test_metrics(max_probs.squeeze(), target.squeeze()),
                           on_epoch = True, logger = True, sync_dist=True)
-        self.log('test_loss', cross_entropy_torch(logits.squeeze(), target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        self.log('test_loss', loss.mean(), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+
+        if self.n_classes <=2:
+            out_probs = probs[:,1] 
+        else: out_probs = probs
+            # max_probs = max_probs[:,1]
 
         if len(target.unique()) != 1:
-            self.log('test_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            self.log('test_auc', self.AUROC(out_probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
             # self.log('val_patient_auc', self.AUROC(patient_score, patient_target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         else:    
             self.log('test_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
@@ -554,63 +634,91 @@ class ModelInterface(pl.LightningModule):
         patient_class_score = 0
 
         for p, s, pr, t in zip(patients, slide_names, probs, target):
+            p = p[0]
+            # print(s[0])
+            # print(pr)
             if p not in complete_patient_dict.keys():
-                complete_patient_dict[p] = [(s, pr)]
+                complete_patient_dict[p] = {'scores':[(s[0], pr)], 'patient_score': 0}
+                # print((s,pr))
+                # complete_patient_dict[p]['scores'] = []
+                # print(t)
                 patient_target.append(t)
             else:
-                complete_patient_dict[p].append((s, pr))
+                complete_patient_dict[p]['scores'].append((s[0], pr))
 
-       
+        # print(complete_patient_dict)
 
         for p in complete_patient_dict.keys():
+            # complete_patient_dict[p] = 0
             score = []
-            for (slide, probs) in complete_patient_dict[p]:
-                # if self.n_classes == 2:
-                #     if probs.argmax().item() == 1: # only if binary and if class 1 is more important!!! Normal vs Diseased or Rejection vs Other
-                #         score.append(probs)
-                    
-                # else: 
+            for (slide, probs) in complete_patient_dict[p]['scores']:
                 score.append(probs)
             # print(score)
             score = torch.stack(score)
             # print(score)
             if self.n_classes == 2:
                 positive_positions = (score.argmax(dim=1) == 1).nonzero().squeeze()
+                # print(positive_positions)
                 if positive_positions.numel() != 0:
                     score = score[positive_positions]
-            else:
-            # score = torch.stack(torch.score)
-            ## get scores that predict class 1:
-            # positive_scores = score.argmax(dim=1)
-            # score = torch.sum(score.argmax(dim=1))
-
-            # if score.item() == 1:
-            #     patient_class_score = 1
-                score = torch.mean(score) #.cpu().detach().numpy()
-            # complete_patient_dict[p]['score'] = score
-            # print(p, score)
-            # patient_list.append(p)    
-            patient_score.append(score)    
-
-        print(patient_score)
+            if len(score.shape) > 1:
+                score = torch.mean(score, dim=0) #.cpu().detach().numpy()
+
+            patient_score.append(score)  
+            complete_patient_dict[p]['patient_score'] = score
+        correct_patients = []
+        false_patients = []
+
+        for patient, label in zip(complete_patient_dict.keys(), patient_target):
+            if label == 0:
+                p_score =  complete_patient_dict[patient]['patient_score']
+                # print(torch.argmax(patient_score))
+                if torch.argmax(p_score) == label:
+                    correct_patients.append(patient)
+                else: 
+                    false_patients.append(patient)
+        # print('Label 0:')
+        # print('Correct Patients: ')
+        # print(correct_patients)
+        # print('False Patients: ')
+        # print(false_patients)
+
+        # print('True positive slides: ')
+        # for p in correct_patients: 
+        #     print(complete_patient_dict[p]['scores'])
+        
+        # print('False Negative Slides')
+        # for p in false_patients: 
+        #     print(complete_patient_dict[p]['scores'])
+        
+        
 
         patient_score = torch.stack(patient_score)
-        # patient_target = torch.stack(patient_target)
-        patient_target = torch.cat(patient_target)
+        
+        # complete_patient_dict[p]['patient_score'] = patient_score
+
+        # print(patient_score)
+        if self.n_classes <=2:
+            patient_score = patient_score[:,1] 
+        patient_target = torch.stack(patient_target)
+        # print(patient_target)
+        # patient_target = torch.cat(patient_target)
+        # self.log_confusion_matrix(max_probs, target, stage='test', comment='patient')
+        self.log_roc_curve(patient_score, patient_target.squeeze(), stage='test')
 
         
         if len(patient_target.unique()) != 1:
-            self.log('test_patient_auc', self.AUROC(patient_score.squeeze(), patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            self.log('test_patient_auc', self.AUROC(patient_score, patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         else:    
             self.log('test_patient_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
         
         self.log_dict(self.test_patient_metrics(patient_score, patient_target),
                           on_epoch = True, logger = True, sync_dist=True)
         
-            
-
-        # precision, recall, thresholds = self.PRC(probs, target)
-
+        
+        self.log_pr_curve(patient_score, patient_target.squeeze(), stage='test')
+        
+        
         
 
         #---->acc log
@@ -624,6 +732,10 @@ class ModelInterface(pl.LightningModule):
             print('test class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
         
+
+
+
+
         #---->random, if shuffle data, change seed
         if self.shuffle == True:
             self.count = self.count+1
@@ -631,12 +743,22 @@ class ModelInterface(pl.LightningModule):
 
     def configure_optimizers(self):
         # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
-        optimizer = create_optimizer(self.optimizer, self.model)
-        # optimizer = PESG(self.model, loss_fn=self.aucm_loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
+        # optimizer = create_optimizer(self.optimizer, self.model)
+        if self.n_classes > 2:
+            # optimizer = PESG(self.model, loss_fn=self.aucm_loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
+            optimizer = create_optimizer(self.optimizer, self.model)
+        else:
+            # optimizer = PDSCA(self.model, loss_fn=self.loss, lr=0.005, margin=1.0, epoch_decay=2e-3, weight_decay=1e-4, beta0=0.9, beta1=0.9, device=self.device)
+            optimizer = create_optimizer(self.optimizer, self.model)
         # optimizer = PDSCA(self.model, loss_fn=self.loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
-        scheduler = {'scheduler': ReduceLROnPlateau(optimizer, mode='min', factor=0.5), 'monitor': 'val_loss', 'frequency': 5}
+        # scheduler = {'scheduler': CosineAnnealingLR(optimizer, mode='min', factor=0.5), 'monitor': 'val_loss', 'frequency': 5}
+        scheduler = {'scheduler': ReduceLROnPlateau(optimizer, mode='min', factor=0.1), 'monitor': 'val_loss', 'frequency': 10}
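+        # ReduceLROnPlateau is monitored on val_loss; with the default interval='epoch', frequency=10 means the scheduler is only stepped every 10th epoch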
+        # scheduler_aucm = {'scheduler': CosineAnnealingWarmRestarts(optimizer_aucm, T_0=20)}
         
+        # return [optimizer_adam, optimizer_aucm], [scheduler_adam, scheduler_aucm]     
         return [optimizer], [scheduler]     
+        # return optimizer_aucm
+        # return [optimizer], [scheduler]
 
     # def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
     #     optimizer.zero_grad(set_to_none=True)
@@ -690,6 +812,19 @@ class ModelInterface(pl.LightningModule):
 
         pass
 
+    def init_backbone(self):
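+        # rebuilds the frozen RetCCL ResNet-50 feature extractor (fc replaced by Identity, so it emits 2048-dim tile features) and moves it to the module's device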
+        self.backbone = 'retccl'
+        # import models.ResNet as ResNet
+        self.model_ft = ResNet.resnet50(num_classes=128, mlp=False, two_branch=False, normlinear=True)
+        home = Path.cwd().parts[1]
+        # pre_model = 
+        # self.model_ft.fc = nn.Identity()
+        # self.model_ft.load_from_checkpoint(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth', strict=False)
+        self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+        for param in self.model_ft.parameters():
+            param.requires_grad = False
+        self.model_ft.fc = torch.nn.Identity()
+        self.model_ft.to(self.device)
 
     def instancialize(self, Model, **other_args):
         """ Instancialize a model using the corresponding parameters
@@ -728,31 +863,97 @@ class ModelInterface(pl.LightningModule):
 
         fig_.clf()
 
-    def log_roc_curve(self, probs, target, stage):
+    def log_roc_curve(self, probs, target, stage, comment=''):
 
         fpr_list, tpr_list, thresholds = self.ROC(probs, target)
 
-        plt.figure(1)
+        # self.AUROC(out_probs, target.squeeze())
+
+        fig, ax = plt.subplots(figsize=(6,6))
         if self.n_classes > 2:
+            auroc_score = multiclass_auroc(probs, target.squeeze(), num_classes=self.n_classes, average=None)
             for i in range(len(fpr_list)):
+                
                 fpr = fpr_list[i].cpu().numpy()
                 tpr = tpr_list[i].cpu().numpy()
-                plt.plot(fpr, tpr, label=f'class_{i}')
+                ax.plot(fpr, tpr, label=f'class_{i}, AUROC={auroc_score[i].item():.3f}')
         else: 
-            print(fpr_list)
+            # print(fpr_list)
+            auroc_score = binary_auroc(probs, target.squeeze())
+
             fpr = fpr_list.cpu().numpy()
             tpr = tpr_list.cpu().numpy()
-            plt.plot(fpr, tpr)
+
+            # df = pd.DataFrame(data = {'fpr': fpr, 'tpr': tpr})
+            # line_plot = sns.lineplot(data=df, x='fpr', y='tpr', label=f'AUROC={auroc_score}', legend='full')
+            # sfig = line_plot.get_figure()
+
+            ax.plot(fpr, tpr, label=f'AUROC={auroc_score.item():.3f}')
+
+
+        ax.set_xlim([0,1])
+        ax.set_ylim([0,1])
+        ax.set_xlabel('False positive rate')
+        ax.set_ylabel('True positive rate')
+        ax.set_title('ROC curve')
+        ax.legend(loc='lower right')
+        # plt.savefig(f'{self.loggers[0].log_dir}/roc.jpg')
+
+        if stage == 'train':
+            self.loggers[0].experiment.add_figure(f'{stage}/ROC_{stage}', fig, self.current_epoch)
+        else:
+            plt.savefig(f'{self.loggers[0].log_dir}/roc_{stage}.jpg', dpi=400)
+            # line_plot.figure.savefig(f'{self.loggers[0].log_dir}/roc_{stage}_sb.jpg')
+
+    def log_pr_curve(self, probs, target, stage, comment=''):
+
+        # fpr_list, tpr_list, thresholds = self.ROC(probs, target)
+        # precision, recall, thresholds = torchmetrics.functional.classification.multiclass_precision_recall_curve(probs, target, num_classes=self.n_classes)
+        # print(precision)
+        # print(recall)
+
+        # baseline = len(target[target==1]) / len(target)
+
+        # plt.figure(1)
+        fig, ax = plt.subplots(figsize=(6,6))
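+        # the dashed baselines mark chance-level precision, i.e. the prevalence of the respective class in the targets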
+        
+        if self.n_classes > 2:
+
+            precision, recall, thresholds = multiclass_precision_recall_curve(probs, target, num_classes=self.n_classes)
+            
+            # print(precision)
+            # print(recall)
+            
+            for i in range(len(precision)):
+                pr = precision[i].cpu().numpy()
+                re = recall[i].cpu().numpy()
+                ax.plot(re, pr, label=f'class_{i}')
+                baseline = len(target[target==i]) / len(target)
+                ax.plot([0,1],[baseline, baseline], linestyle='--', label=f'Baseline_{i}')
+
+        else: 
+            # print(fpr_list)
+            precision, recall, thresholds = binary_precision_recall_curve(probs, target)
+            baseline = len(target[target==1]) / len(target)
+            pr = precision.cpu().numpy()
+            re = recall.cpu().numpy()
+            ax.plot(re, pr)
         
-        plt.xlabel('False positive rate')
-        plt.ylabel('True positive rate')
-        plt.title('ROC curve')
-        plt.savefig(f'{self.loggers[0].log_dir}/roc.jpg')
+            ax.plot([0,1], [baseline, baseline], linestyle='--', label='Baseline')
+
+        ax.set_xlim([0,1])
+        ax.set_ylim([0,1])
+
+        ax.set_xlabel('Recall')
+        ax.set_ylabel('Precision')
+        ax.set_title('PR curve')
+        ax.legend(loc='lower right')
+        # plt.savefig(f'{self.loggers[0].log_dir}/pr_{stage}.jpg')
 
         if stage == 'train':
-            self.loggers[0].experiment.add_figure(f'{stage}/ROC', plt, self.current_epoch)
+            self.loggers[0].experiment.add_figure(f'{stage}/PR_{stage}', fig, self.current_epoch)
         else:
-            plt.savefig(f'{self.loggers[0].log_dir}/roc.jpg', dpi=400)
+            fig.savefig(f'{self.loggers[0].log_dir}/pr_{stage}.jpg', dpi=400)
 
     
 
diff --git a/code/models/model_interface_classic.py b/code/models/model_interface_classic.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3eda003007dc350c97daefb18def601eab393b8
--- /dev/null
+++ b/code/models/model_interface_classic.py
@@ -0,0 +1,757 @@
+import sys
+import numpy as np
+import re
+import inspect
+import importlib
+import random
+import pandas as pd
+import seaborn as sns
+from pathlib import Path
+from matplotlib import pyplot as plt
+import cv2
+from PIL import Image
+from pytorch_pretrained_vit import ViT
+
+#---->
+from MyOptimizer import create_optimizer
+from MyLoss import create_loss
+from utils.utils import cross_entropy_torch
+from utils.custom_resnet50 import resnet50_baseline
+
+from timm.loss import AsymmetricLossSingleLabel
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
+from libauc.losses import AUCMLoss, AUCM_MultiLabel, CompositionalAUCLoss
+from libauc.optimizers import PESG, PDSCA
+#---->
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchmetrics
+from torchmetrics.functional import stat_scores
+from torch import optim as optim
+from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
+
+from monai.config import KeysCollection
+from monai.data import Dataset, load_decathlon_datalist
+from monai.data.wsi_reader import WSIReader
+from monai.metrics import Cumulative, CumulativeAverage
+from monai.networks.nets import milmodel
+
+# from sklearn.metrics import roc_curve, auc, roc_curve_score
+
+
+#---->
+import pytorch_lightning as pl
+from .vision_transformer import vit_small
+import torchvision
+from torchvision import models
+from torchvision.models import resnet
+from transformers import AutoFeatureExtractor, ViTModel, SwinModel
+
+from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+from captum.attr import LayerGradCam
+import models.ResNet as ResNet
+
+class FeatureExtractor(pl.LightningModule):
+    def __init__(self, model_name, n_classes):
+        super().__init__()
+        self.n_classes = n_classes
+        # self.out_features was previously never set; 1024 is assumed here, matching the projection size used elsewhere in this repo
+        self.out_features = 1024
+
+        self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True)
+        home = Path.cwd().parts[1]
+        self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+        for param in self.model_ft.parameters():
+            param.requires_grad = False
+        self.model_ft.fc = nn.Linear(2048, self.out_features)
+
+    def forward(self,x):
+        return self.model_ft(x)
+
+class ModelInterface_Classic(pl.LightningModule):
+
+    #---->init
+    def __init__(self, model, loss, optimizer, **kargs):
+        super(ModelInterface_Classic, self).__init__()
+        self.save_hyperparameters()
+        self.n_classes = model.n_classes
+        
+        # if self.n_classes>2:
+        #     self.aucm_loss = AUCM_MultiLabel(num_classes = self.n_classes, device=self.device)
+        # else:
+        #     self.aucm_loss = CompositionalAUCLoss()
+        # self.asl = AsymmetricLossSingleLabel()
+        # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
+        self.loss = create_loss(loss, model.n_classes)
+
+        # self.loss = 
+        # print(self.model)
+        self.model_name = model.name
+        
+        
+        # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
+        self.optimizer = optimizer
+        
+        self.save_path = kargs['log']
+        if Path(self.save_path).parts[3] == 'tcmr':
+            temp = list(Path(self.save_path).parts)
+            # print(temp)
+            temp[3] = 'tcmr_viral'
+            self.save_path = '/'.join(temp)
+
+        # if kargs['task']:
+        #     self.task = kargs['task']
+        self.task = Path(self.save_path).parts[3]
+
+
+        #---->acc
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        self.data_patient = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        # print(self.experiment)
+        #---->Metrics
+        if self.n_classes > 2: 
+            self.AUROC = torchmetrics.AUROC(task='multiclass', num_classes = self.n_classes, average='macro')
+            self.PRC = torchmetrics.PrecisionRecallCurve(task='multiclass', num_classes = self.n_classes)
+            self.ROC = torchmetrics.ROC(task='multiclass', num_classes=self.n_classes)
+            self.confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes = self.n_classes) 
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(task='multiclass', num_classes = self.n_classes,
+                                                                           average='weighted'),
+                                                     torchmetrics.CohenKappa(task='multiclass', num_classes = self.n_classes),
+                                                     torchmetrics.F1Score(task='multiclass', num_classes = self.n_classes,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(task='multiclass', average = 'macro',
+                                                                         num_classes = self.n_classes),
+                                                     torchmetrics.Precision(task='multiclass', average = 'macro',
+                                                                            num_classes = self.n_classes),
+                                                     torchmetrics.Specificity(task='multiclass', average = 'macro',
+                                                                            num_classes = self.n_classes)])
+                                                                            
+        else : 
+            self.AUROC = torchmetrics.AUROC(task='binary')
+            # self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average = 'weighted')
+            self.PRC = torchmetrics.PrecisionRecallCurve(task='binary')
+            self.ROC = torchmetrics.ROC(task='binary')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(task='binary'),
+                                                     torchmetrics.CohenKappa(task='binary'),
+                                                     torchmetrics.F1Score(task='binary'),
+                                                     torchmetrics.Recall(task='binary'),
+                                                     torchmetrics.Precision(task='binary')
+                                                     ])
+            self.confusion_matrix = torchmetrics.ConfusionMatrix(task='binary')                                                                    
+        self.valid_metrics = metrics.clone(prefix = 'val_')
+        self.valid_patient_metrics = metrics.clone(prefix = 'val_patient_')
+        self.test_metrics = metrics.clone(prefix = 'test_')
+        self.test_patient_metrics = metrics.clone(prefix = 'test_patient_')
+
+        #--->random
+        self.shuffle = kargs['data'].data_shuffle
+        self.count = 0
+        # self.model_name = kargs['backbone']
+
+
+        if self.model_name == 'features':
+            self.model = None
+        elif self.model_name == 'inception':
+            # self.model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
+            self.model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', weights='Inception_V3_Weights.DEFAULT')
+            self.model.aux_logits = False
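+            # freeze the first 14 child modules of Inception-v3; only the later blocks and the new classification head are fine-tuned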
+            ct = 0
+            for child in self.model.children():
+                ct += 1
+                if ct < 15:
+                    for parameter in child.parameters():
+                        parameter.requires_grad=False
+            # for parameter in self.model.parameters():
+                # parameter.requires_grad = False
+
+            
+            # self.model.AuxLogits.fc = nn.Linear(768, self.n_classes)
+            self.model.fc = nn.Linear(self.model.fc.in_features, self.n_classes)
+        elif self.model_name == 'resnet18':
+            self.model = models.resnet18(weights='IMAGENET1K_V1')
+            # modules = list(resnet18.children())[:-1]
+            # frozen_layers = 8
+            # for child in self.model.children():
+
+            ct = 0
+            for child in self.model.children():
+                ct += 1
+                if ct < 7:
+                    for parameter in child.parameters():
+                        parameter.requires_grad=False
+            self.model.fc = nn.Sequential(
+                nn.Linear(self.model.fc.in_features, self.n_classes),
+            )
+        elif self.model_name == 'retccl':
+            # import models.ResNet as ResNet
+            self.model = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True)
+            home = Path.cwd().parts[1]
+            # pre_model = 
+            # self.model.fc = nn.Identity()
+            # self.model.load_from_checkpoint(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth', strict=False)
+            self.model.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+            for param in self.model.parameters():
+                param.requires_grad = False
+            self.model.fc = nn.Sequential(
+                nn.Linear(2048, 1024),
+                nn.GELU(),
+                nn.LayerNorm(1024),
+                nn.Linear(1024, 512),
+                nn.GELU(),
+                nn.LayerNorm(512),
+                nn.Linear(512, self.n_classes)
+            )
+        elif self.model_name == 'vit':
+            self.model = ViT('B_32_imagenet1k', pretrained = True) #vis=vis
+            for param in self.model.parameters():
+                param.requires_grad = False
+            self.model.fc = nn.Linear(self.model.fc.in_features, self.n_classes)
+            # print(self.model)
+            # input_size = 384
+
+        elif self.model_name == 'resnet50':
+        
+            self.model = resnet50_baseline(pretrained=True)
+            ct = 0
+            for child in self.model.children():
+                ct += 1
+                if ct < len(list(self.model.children())) - 10:
+                    for parameter in child.parameters():
+                        parameter.requires_grad=False
+            self.model.fc = nn.Sequential(
+                nn.Linear(self.model.fc.in_features, self.n_classes),
+            )
+            
+        elif self.model_name == 'efficientnet':
+            efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+            for param in efficientnet.parameters():
+                param.requires_grad = False
+            # efn = list(efficientnet.children())[:-1]
+            efficientnet.classifier.fc = nn.Linear(1280, self.out_features)  # NOTE: self.out_features is not defined in this class; this branch needs a feature size and a final n_classes head before it can run
+            self.model = nn.Sequential(
+                efficientnet,
+            )
+        elif self.model_name == 'simple': #mil-ab attention; NOTE: relies on a View helper and self.out_features, neither of which is defined in this module
+            feature_extracting = False
+            self.model = nn.Sequential(
+                nn.Conv2d(3, 20, kernel_size=5),
+                nn.ReLU(),
+                nn.MaxPool2d(2, stride=2),
+                nn.Conv2d(20, 50, kernel_size=5),
+                nn.ReLU(),
+                nn.MaxPool2d(2, stride=2),
+                View((-1, 53*53)),
+                nn.Linear(53*53, self.out_features),
+                nn.ReLU(),
+            )
+
+    # def __build_
+
+    def forward(self, x):
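+        # tiles may arrive with a leading singleton (bag) dimension; squeeze it so the backbone receives a standard [N, C, H, W] batch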
+        # print(x.shape)
+        x = x.squeeze(0)
+        # print(x.shape)
+        return self.model(x)
+
+    def step(self, input):
+
+        input = input.float()
+        # input = input
+        # logits, _ = self(input.contiguous()) 
+        logits = self(input.contiguous())
+        # logits = logits
+        # print(F.softmax(logits))
+        # print(torch.argmax(logits, dim=0))
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim=1)
+        # Y_hat = torch.argmax(logits, dim=0).unsqueeze(0)
+        # Y_prob = F.softmax(logits, dim = 0)
+
+        # print(Y_hat)
+        # print(Y_prob)
+
+
+        # Y_hat = torch.argmax(logits, dim=1)
+        # Y_prob = F.softmax(logits, dim=1)
+        
+        return logits, Y_prob, Y_hat
+
+    def training_step(self, batch, batch_idx):
+
+        input, label, _= batch
+
+        # label_filled = torch.full([input.shape[1]], label.item(), device=self.device)
+
+        logits, Y_prob, Y_hat = self.step(input) 
+
+        loss = self.loss(logits, label)
+
+        
+        for y, y_hat in zip(label, Y_hat):    
+            y = int(y)
+            # print(Y_hat)
+            self.data[y]["count"] += 1
+            self.data[y]["correct"] += (int(y_hat) == y)
+
+        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+
+        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
+
+    def training_epoch_end(self, training_step_outputs):
+
+        # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
+        max_probs = torch.cat([x['Y_hat'] for x in training_step_outputs])
+        target = torch.cat([x['label'] for x in training_step_outputs])
+
+        # probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
+        # probs = torch.stack([x['Y_prob'] for x in training_step_outputs], dim=0)
+        # max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs])
+        # target = torch.stack([x['label'] for x in training_step_outputs], dim=0).int()
+
+        if self.current_epoch % 5 == 0:
+            for c in range(self.n_classes):
+                count = self.data[c]["count"]
+                correct = self.data[c]["correct"]
+                if count == 0: 
+                    acc = None
+                else:
+                    acc = float(correct) / count
+                print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+        if self.current_epoch % 10 == 0:
+            self.log_confusion_matrix(max_probs, target.squeeze(), stage='train')
+
+        # print(probs)
+        # print(target)
+        # print(probs.shape)
+        # print(target.shape)
+        if self.n_classes <=2:
+            out_probs = probs[:,1]
+        else:
+            out_probs = probs
+        self.log('Train/auc', self.AUROC(out_probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+
+    def validation_step(self, batch, batch_idx):
+
+        input, label, (wsi_name, tile_name, patient) = batch
+        # label_filled = torch.full([input.shape[1]], label.item(), device=self.device)
+        
+        logits, Y_prob, Y_hat = self.step(input) 
+        logits = logits.detach()
+        Y_prob = Y_prob.detach()
+        Y_hat = Y_hat.detach()
+
+        # loss = self.loss(logits, label)
+        loss = cross_entropy_torch(logits, label)
+
+        for y, y_hat in zip(label, Y_hat):    
+            y = int(y)
+            # print(Y_hat)
+            self.data[y]["count"] += 1
+            self.data[y]["correct"] += (int(y_hat) == y)
+        
+        # self.data[Y]["correct"] += (Y_hat.item() == Y)
+        self.log('val_loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        # print(Y_hat)
+        # print(label)
+        # self.log('val_aucm_loss', aucm_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1, sync_dist=True)
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': wsi_name, 'patient': patient, 'tile_name': tile_name, 'loss': loss}
+
+
+    def validation_epoch_end(self, val_step_outputs):
+        
+        logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
+        max_probs = torch.cat([x['Y_hat'] for x in val_step_outputs])
+        target = torch.cat([x['label'] for x in val_step_outputs])
+        # slide_names = [list(x['name']) for x in val_step_outputs]
+        slide_names = []
+        for x in val_step_outputs:
+            slide_names += list(x['name'])
+        # patients = [list(x['patient']) for x in val_step_outputs]
+        patients = []
+        for x in val_step_outputs:
+            patients += list(x['patient'])
+        tile_name = []
+        for x in val_step_outputs:
+            tile_name += list(x['tile_name'])
+
+        loss = torch.stack([x['loss'] for x in val_step_outputs])
+
+        self.log_dict(self.valid_metrics(max_probs.squeeze(), target.squeeze()),
+                          on_epoch = True, logger = True, sync_dist=True)
+
+        if self.n_classes <= 2:
+            out_probs = probs[:, 1]
+        else:
+            out_probs = probs
+        if len(target.unique()) != 1:
+            self.log('val_auc', self.AUROC(out_probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            # self.log('val_patient_auc', self.AUROC(patient_score, patient_target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:    
+            self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+
+
+
+        self.log_confusion_matrix(max_probs, target.squeeze(), stage='val')
+
+        #----> log per patient metrics
+        complete_patient_dict = {}
+        patient_list = []            
+        patient_score = []      
+        patient_target = []
+
+        for p, s, pr, t in zip(patients, slide_names, probs, target):
+
+            if p not in complete_patient_dict.keys():
+                complete_patient_dict[p] = {s:[]}
+                patient_target.append(t)
+                
+            elif s not in complete_patient_dict[p].keys():
+                complete_patient_dict[p][s] = []
+            complete_patient_dict[p][s].append(pr)
+            
+
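+        # Aggregate tile-level probabilities per slide and then per patient; for binary tasks only the positively predicted entries are averaged when any exist.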
+        for p in complete_patient_dict.keys():
+            score = []
+            for slide in complete_patient_dict[p].keys():
+
+                slide_score = torch.stack(complete_patient_dict[p][slide])
+                if self.n_classes == 2:
+                    positive_positions = (slide_score.argmax(dim=1) == 1).nonzero().squeeze()
+                    if positive_positions.numel() != 0:
+                        slide_score = slide_score[positive_positions]
+                if len(slide_score.shape)>1:
+                    slide_score = torch.mean(slide_score, dim=0)
+
+                score.append(slide_score)
+            score = torch.stack(score)
+            if self.n_classes == 2:
+                positive_positions = (score.argmax(dim=1) == 1).nonzero().squeeze()
+                if positive_positions.numel() != 0:
+                    score = score[positive_positions]
+            if len(score.shape) > 1:
+                score = torch.mean(score, dim=0)
+            patient_score.append(score)    
+
+        patient_score = torch.stack(patient_score)
+        patient_target = torch.stack(patient_target)
+        if self.n_classes <=2:
+            patient_score = patient_score[:,1]
+        if len(patient_target.unique()) != 1:
+            self.log('val_patient_auc', self.AUROC(patient_score, patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:    
+            self.log('val_patient_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        
+        self.log_dict(self.valid_patient_metrics(patient_score, patient_target),
+                          on_epoch = True, logger = True, sync_dist=True)
+        
+            
+
+        # precision, recall, thresholds = self.PRC(probs, target)
+
+        
+
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('val class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        
+        #---->random, if shuffle data, change seed
+        # if self.shuffle == True:
+        #     self.count = self.count+1
+        #     random.seed(self.count*50)
+
+
+
+    def test_step(self, batch, batch_idx):
+
+        input, label, (wsi_name, batch_names, patient) = batch
+        label = label.float()
+        
+        logits, Y_prob, Y_hat = self.step(input) 
+
+        #---->acc log
+        Y = int(label)
+        # Y = torch.argmax(label)
+
+        # print(Y_hat)
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (int(Y_hat) == Y)
+        # self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': wsi_name, 'patient': patient}
+
+    def test_epoch_end(self, output_results):
+
+        logits = torch.cat([x['logits'] for x in output_results], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in output_results])
+        max_probs = torch.cat([x['Y_hat'] for x in output_results])
+        target = torch.cat([x['label'] for x in output_results])
+        slide_names = []
+        patients = []
+        for x in output_results:
+            slide_names += list(x['name'])
+            patients += list(x['patient'])
+
+
+
+        # logits = torch.cat([x['logits'] for x in output_results], dim = 0)
+        # probs = torch.cat([x['Y_prob'] for x in output_results])
+        # max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        # target = torch.stack([x['label'] for x in output_results]).int()
+        # slide_names = [x['name'] for x in output_results]
+        # patients = [x['patient'] for x in output_results]
+        
+        self.log_dict(self.test_metrics(max_probs.squeeze(), target.squeeze()),
+                          on_epoch = True, logger = True, sync_dist=True)
+        self.log('test_loss', cross_entropy_torch(logits.squeeze(), target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+
+        # if self.n_classes <=2:
+        #     out_probs = probs[:,1] 
+            # max_probs = max_probs[:,1]
+
+        if self.n_classes <= 2:
+            out_probs = probs[:, 1]
+        else:
+            out_probs = probs
+        if len(target.unique()) != 1:
+            self.log('test_auc', self.AUROC(out_probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:
+            self.log('test_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+
+
+
+        #----> log confusion matrix
+        self.log_confusion_matrix(max_probs.squeeze(), target.squeeze(), stage='test')
+
+        #----> log per patient metrics
+        complete_patient_dict = {}
+        patient_list = []            
+        patient_score = []      
+        patient_target = []
+
+        for p, s, pr, t in zip(patients, slide_names, probs, target):
+
+            if p not in complete_patient_dict.keys():
+                complete_patient_dict[p] = {s:[]}
+                patient_target.append(t)
+                
+            elif s not in complete_patient_dict[p].keys():
+                complete_patient_dict[p][s] = []
+            complete_patient_dict[p][s].append(pr)
+            
+
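+        # Same slide-to-patient aggregation as in validation: average per-slide probabilities, keeping only positively predicted entries for binary tasks when any exist.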
+        for p in complete_patient_dict.keys():
+            score = []
+            for slide in complete_patient_dict[p].keys():
+
+                slide_score = torch.stack(complete_patient_dict[p][slide])
+                if self.n_classes == 2:
+                    positive_positions = (slide_score.argmax(dim=1) == 1).nonzero().squeeze()
+                    if positive_positions.numel() != 0:
+                        slide_score = slide_score[positive_positions]
+                if len(slide_score.shape)>1:
+                    slide_score = torch.mean(slide_score, dim=0)
+
+                score.append(slide_score)
+            score = torch.stack(score)
+            if self.n_classes == 2:
+                positive_positions = (score.argmax(dim=1) == 1).nonzero().squeeze()
+                if positive_positions.numel() != 0:
+                    score = score[positive_positions]
+            if len(score.shape) > 1:
+                score = torch.mean(score, dim=0)
+            patient_score.append(score)    
+
+        patient_score = torch.stack(patient_score)
+        patient_target = torch.stack(patient_target)
+        if self.n_classes <=2:
+            patient_score = patient_score[:,1]
+        if len(patient_target.unique()) != 1:
+            self.log('test_patient_auc', self.AUROC(patient_score, patient_target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:    
+            self.log('test_patient_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        
+        self.log_dict(self.valid_patient_metrics(patient_score, patient_target),
+                          on_epoch = True, logger = True, sync_dist=True)
+        
+            
+
+        # precision, recall, thresholds = self.PRC(probs, target)
+
+        
+
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('test class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        
+
+        #---->random, if shuffle data, change seed
+        if self.shuffle:
+            self.count = self.count+1
+            random.seed(self.count*50)
+
+    def configure_optimizers(self):
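+        # Build the optimizer from the config and reduce the learning rate when val_loss plateaus (checked every 5 epochs).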
+        # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
+        optimizer = create_optimizer(self.optimizer, self.model)
+        # optimizer_aucm = PESG(self.model, loss_fn=self.aucm_loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
+        # optimizer_aucm = PDSCA(self.model, loss_fn=self.aucm_loss, lr=0.005, margin=1.0, epoch_decay=2e-3, weight_decay=1e-4, beta0=0.9, beta1=0.9, device=self.device)
+        # optimizer = PDSCA(self.model, loss_fn=self.loss, lr=self.optimizer.lr, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, device=self.device)
+        # scheduler = {'scheduler': CosineAnnealingLR(optimizer, mode='min', factor=0.5), 'monitor': 'val_loss', 'frequency': 5}
+        scheduler = {'scheduler': ReduceLROnPlateau(optimizer, mode='min', factor=0.1), 'monitor': 'val_loss', 'frequency': 5}
+        # scheduler_aucm = {'scheduler': CosineAnnealingWarmRestarts(optimizer_aucm, T_0=20)}
+        
+        # return [optimizer_adam, optimizer_aucm], [scheduler_adam, scheduler_aucm]     
+        # return [optimizer_aucm], [scheduler_aucm]     
+        return [optimizer], [scheduler]
+
+    # def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
+    #     optimizer.zero_grad(set_to_none=True)
+
+    def reshape_transform(self, tensor):
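+        # Pad the token sequence up to the next square number by repeating its first tokens, then reshape to (B, C, H, W) so CAM methods can treat it as a 2D feature map.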
+        # print(tensor.shape)
+        H = tensor.shape[1]
+        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
+        add_length = _H * _W - H
+        tensor = torch.cat([tensor, tensor[:,:add_length,:]],dim = 1)
+        result = tensor[:, :, :].reshape(tensor.size(0), _H, _W, tensor.size(2))
+        result = result.transpose(2,3).transpose(1,2)
+        # print(result.shape)
+        return result
+
+    def load_model(self):
+        name = self.hparams.model.name
+        # Convert the snake_case model file name (e.g. `trans_unet.py`) to the
+        # CamelCase class name (`TransUnet`). Model files should be named in
+        # snake_case with a matching CamelCase class or function name.
+        if name == 'ViT':
+            self.model = ViT
+
+        if '_' in name:
+            camel_name = ''.join([i.capitalize() for i in name.split('_')])
+        else:
+            camel_name = name
+        try:
+            Model = getattr(importlib.import_module(
+                f'models.{name}'), camel_name)
+        except (ImportError, AttributeError):
+            raise ValueError('Invalid Module File Name or Invalid Class Name!')
+        self.model = self.instancialize(Model)
+
+        # if backbone == 'retccl':
+
+        #     self.model_ft = ResNet.resnet50(num_classes=self.n_classes, mlp=False, two_branch=False, normlinear=True)
+        #     home = Path.cwd().parts[1]
+        #     # self.model_ft.fc = nn.Identity()
+        #     # self.model_ft.load_from_checkpoint(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth', strict=False)
+        #     self.model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+        #     for param in self.model_ft.parameters():
+        #         param.requires_grad = False
+        #     self.model_ft.fc = nn.Linear(2048, self.out_features)
+        
+        # elif backbone == 'resnet50':
+        #     self.model_ft = resnet50_baseline(pretrained=True)
+        #     for param in self.model_ft.parameters():
+        #         param.requires_grad = False
+
+        pass
+
+
+    def instancialize(self, Model, **other_args):
+        """ Instancialize a model using the corresponding parameters
+            from self.hparams dictionary. You can also input any args
+            to overwrite the corresponding value in self.hparams.
+        """
+        class_args = inspect.getfullargspec(Model.__init__).args[1:]
+        inkeys = self.hparams.model.keys()
+        args1 = {}
+        for arg in class_args:
+            if arg in inkeys:
+                args1[arg] = getattr(self.hparams.model, arg)
+        args1.update(other_args)
+
+
+        return Model(**args1)
+
+    def log_image(self, tensor, stage, name):
+        
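+        # Normalize the tile tensor to the 0-255 range and log it as an image to TensorBoard.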
+        tile = tensor.cpu().numpy().transpose(1, 2, 0)
+        tile = (tile - tile.min()) / (tile.max() - tile.min()) * 255
+        tile = tile.astype(np.uint8)
+        self.loggers[0].experiment.add_image(f'{stage}/{name}', tile, self.current_epoch, dataformats='HWC')
+
+
+    def log_confusion_matrix(self, max_probs, target, stage):
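+        # Render the confusion matrix as a seaborn heatmap; training-stage matrices go to TensorBoard, others are saved to disk.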
+        confmat = self.confusion_matrix(max_probs, target)
+        print(confmat)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        fig_ = sns.heatmap(df_cm, annot=True, fmt='d', cmap='Spectral').get_figure()
+        if stage == 'train':
+            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+        else:
+            fig_.savefig(f'{self.loggers[0].log_dir}/cm_{stage}.png', dpi=400)
+
+        fig_.clf()
+
+    def log_roc_curve(self, probs, target, stage, comment=''):
+
+        fpr_list, tpr_list, thresholds = self.ROC(probs, target)
+        # print(fpr_list)
+        # print(tpr_list)
+
+        fig = plt.figure(1)
+        if self.n_classes > 2:
+            for i in range(len(fpr_list)):
+                fpr = fpr_list[i].cpu().numpy()
+                tpr = tpr_list[i].cpu().numpy()
+                plt.plot(fpr, tpr, label=f'class_{i}')
+        else: 
+            # print(fpr_list)
+            fpr = fpr_list[0].cpu().numpy()
+            tpr = tpr_list[0].cpu().numpy()
+            plt.plot(fpr, tpr)
+        
+        plt.xlabel('False positive rate')
+        plt.ylabel('True positive rate')
+        plt.title('ROC curve')
+        plt.savefig(f'{self.loggers[0].log_dir}/roc.jpg')
+
+        if stage == 'train':
+            self.loggers[0].experiment.add_figure(f'{stage}/ROC_{comment}', fig, self.current_epoch)
+        else:
+            plt.savefig(f'{self.loggers[0].log_dir}/roc_{comment}.jpg', dpi=400)
+
+    
+
+class View(nn.Module):
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, input):
+        '''
+        Reshapes the input according to the shape saved in the view data structure.
+        '''
+        out = input.view(*self.shape)
+        return out
+
diff --git a/code/test_visualize.py b/code/test_visualize.py
index 79f814cd75aa73b498cfb140d36ad78a8d0b1568..623066e5ee8489e9d58d8869aaab0186a3689ff9 100644
--- a/code/test_visualize.py
+++ b/code/test_visualize.py
@@ -51,26 +51,63 @@ class custom_test_module(ModelInterface):
     # self.task = kargs['task']    
     # self.task = 'tcmr_viral'
 
+    # def forward(self, x):
+    #     batch_size = x.shape[0]
+    #     bag_size = x.shape[1]
+    #     x = x.view(batch_size*bag_size, x.shape[2], x.shape[3], x.shape[4])
+    #     feats = self.model_ft(x).unsqueeze(0)
+    #     feats = feats.view(batch_size, bag_size, -1)
+    #     return self.model(feats)
+
+
+
+
     def test_step(self, batch, batch_idx):
 
-        torch.set_grad_enabled(True)
+        print('custom: ', self.backbone)
+        print(self.model_ft.device)
+        
 
-        input_data, label, (wsi_name, batch_names, patient) = batch
-        patient = patient[0]
-        wsi_name = wsi_name[0]
-        label = label.float()
-        # logits, Y_prob, Y_hat = self.step(data) 
-        # print(data.shape)
-        input_data = input_data.squeeze(0).float()
-        # print(self.model_ft)
-        # print(self.model)
-        logits, _ = self(input_data)
-        # attn = attn.detach()
-        # logits = logits.detach()
+        torch.set_grad_enabled(True)
 
-        Y = torch.argmax(label)
+        input, label, (wsi_name, patient) = batch
+        
+        print(input.device)
+        # input, label, (wsi_name, batch_names, patient) = batch
+        # label = label.float()
+        # 
+        # feature extraction
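+        # Flatten the bag dimension so the feature extractor sees individual tiles, then restore (batch, bag, features) for the MIL aggregator.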
+        x = input
+        batch_size = x.shape[0]
+        bag_size = x.shape[1]
+        # new_shape = (batch_size*bag_size, x.shape[2], x.shape[3], x.shape[4])
+        # x = x.view(new_shape)
+        x = x.view(batch_size*bag_size, x.shape[2], x.shape[3], x.shape[4])
+        data_ft = self.model_ft(x).unsqueeze(0)
+        data_ft = data_ft.view(batch_size, bag_size, -1)
+
+        logits = self.model(data_ft) 
         Y_hat = torch.argmax(logits, dim=1)
-        Y_prob = F.softmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+        # logits, Y_prob, Y_hat = self.model(data_ft) 
+        loss = self.loss(logits, label)
+
+        # input_data, label, (wsi_name, batch_names, patient) = batch
+        # patient = patient[0]
+        # wsi_name = wsi_name[0]
+        # label = label.float()
+        # # logits, Y_prob, Y_hat = self.step(data) 
+        # # print(data.shape)
+        # input_data = input_data.squeeze(0).float()
+        # # print(self.model_ft)
+        # # print(self.model)
+        # logits, _ = self(input_data)
+        # # attn = attn.detach()
+        # # logits = logits.detach()
+
+        # Y = torch.argmax(label)
+        # Y_hat = torch.argmax(logits, dim=1)
+        # Y_prob = F.softmax(logits, dim=1)
 
         
 
@@ -92,13 +129,13 @@ class custom_test_module(ModelInterface):
             target_layers = [self.model.attention_weights]
             self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
 
-        if self.model_ft:
-            data_ft = self.model_ft(input_data).unsqueeze(0).float()
-        else:
-            data_ft = input_data.unsqueeze(0).float()
-        instance_count = input_data.size(0)
+        # if self.model_ft:
+        #     data_ft = self.model_ft(input_data).unsqueeze(0).float()
+        # else:
+        #     data_ft = input_data.unsqueeze(0).float()
+        instance_count = input.size(0)
         # data_ft.requires_grad=True
-        
+        Y = torch.argmax(label)
         target = [ClassifierOutputTarget(Y)]
         # print(target)
         
@@ -111,13 +148,13 @@ class custom_test_module(ModelInterface):
         k = 10
         summed = torch.mean(grayscale_cam, dim=2)
         topk_tiles, topk_indices = torch.topk(summed.squeeze(0), k, dim=0)
-        topk_data = input_data[topk_indices].detach()
+        topk_data = input[topk_indices].detach()
         # print(topk_tiles)
         
         #----------------------------------------------------
         # Log Correct/Count
         #----------------------------------------------------
-        Y = torch.argmax(label)
+        # Y = torch.argmax(label)
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
@@ -143,58 +180,84 @@ class custom_test_module(ModelInterface):
 
         logits = torch.cat([x['logits'] for x in output_results], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in output_results])
-        max_probs = torch.stack([x['Y_hat'] for x in output_results])
-        # target = torch.stack([x['label'] for x in output_results], dim = 0)
-        target = torch.stack([x['label'] for x in output_results])
-        # target = torch.argmax(target, dim=1)
-        slide = [x['name'] for x in output_results]
+        # max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        max_probs = torch.cat([x['Y_hat'] for x in output_results])
+        target = torch.cat([x['label'] for x in output_results])
+        slide_names = [x['name'] for x in output_results]
         patients = [x['patient'] for x in output_results]
-        topk_tiles = [x['topk_data'] for x in output_results]
+        loss = torch.stack([x['loss'] for x in output_results])
         #---->
 
-        if len(target.unique()) !=1:
-            auc = self.AUROC(probs, target)
-        else: auc = torch.tensor(0)
-        metrics = self.test_metrics(logits , target)
+        self.log_dict(self.test_metrics(max_probs.squeeze(), target.squeeze()),
+                          on_epoch = True, logger = True, sync_dist=True)
+        self.log('test_loss', loss.mean(), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
+        if self.n_classes <=2:
+            out_probs = probs[:,1] 
+        else: out_probs = probs
+            # max_probs = max_probs[:,1]
 
-        # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
-        metrics['test_auc'] = auc
+        if len(target.unique()) != 1:
+            self.log('test_auc', self.AUROC(out_probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+            # self.log('val_patient_auc', self.AUROC(patient_score, patient_target), prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
+        else:    
+            self.log('test_auc', 0.0, prog_bar=True, on_epoch=True, logger=True, sync_dist=True)
 
-        # print(metrics)
-        np_metrics = {k: metrics[k].item() for k in metrics.keys()}
-        # print(np_metrics)
 
-        
+
+        #----> log confusion matrix
+        self.log_confusion_matrix(max_probs, target, stage='test')
+
+        #----> log per patient metrics
         complete_patient_dict = {}
-        '''
-        Patient
-        -> slides:
-            -> SlideName:
-                ->probs = [0.5, 0.5] 
-                ->topk = [10,3,224,224]
-        -> score = []
-        '''
+        patient_list = []            
+        patient_score = []      
+        patient_target = []
+        patient_class_score = 0
 
 
-        for p, s, l, topkt in zip(patients, slide, probs, topk_tiles):
+        for p, s, pr, t in zip(patients, slide_names, probs, target):
+            p = p[0]
+            # print(s[0])
+            # print(pr)
             if p not in complete_patient_dict.keys():
-                complete_patient_dict[p] = {'slides':{}}
-            complete_patient_dict[p]['slides'][s] = {'probs': l, 'topk':topkt}
+                complete_patient_dict[p] = {'scores':[(s[0], pr)], 'patient_score': 0}
+                # print((s,pr))
+                # complete_patient_dict[p]['scores'] = []
+                # print(t)
+                patient_target.append(t)
+            else:
+                complete_patient_dict[p]['scores'].append((s[0], pr))
 
-        patient_list = []            
-        patient_score = []            
         for p in complete_patient_dict.keys():
+            # complete_patient_dict[p] = 0
             score = []
-            
-            for s in complete_patient_dict[p]['slides'].keys():
-                score.append(complete_patient_dict[p]['slides'][s]['probs'])
-            score = torch.mean(torch.stack(score), dim=0) #.cpu().detach().numpy()
-            complete_patient_dict[p]['score'] = score
-            # print(p, score)
-            patient_list.append(p)    
-            patient_score.append(score)    
-
+            for (slide, probs) in complete_patient_dict[p]['scores']:
+                score.append(probs)
+            # print(score)
+            score = torch.stack(score)
+            # print(score)
+            if self.n_classes == 2:
+                positive_positions = (score.argmax(dim=1) == 1).nonzero().squeeze()
+                # print(positive_positions)
+                if positive_positions.numel() != 0:
+                    score = score[positive_positions]
+            if len(score.shape) > 1:
+                score = torch.mean(score, dim=0) #.cpu().detach().numpy()
+
+            patient_score.append(score)  
+            complete_patient_dict[p]['patient_score'] = score
+        correct_patients = []
+        false_patients = []
+
+        for patient, label in zip(complete_patient_dict.keys(), patient_target):
+            if label == 0:
+                p_score =  complete_patient_dict[patient]['patient_score']
+                # print(torch.argmax(patient_score))
+                if torch.argmax(p_score) == label:
+                    correct_patients.append(patient)
+                else: 
+                    false_patients.append(patient)
         # print(patient_list)
         #topk patients: 
 
@@ -231,7 +294,6 @@ class custom_test_module(ModelInterface):
 
             patient_top_slides = {} 
             for p in topk_patients:
-                # print(p)
                 output_dict[class_name][p] = {}
                 output_dict[class_name][p]['Patient_Score'] = complete_patient_dict[p]['score'].cpu().detach().numpy().tolist()
 
@@ -303,7 +365,7 @@ class custom_test_module(ModelInterface):
         #     return coords
 
         home = Path.cwd().parts[1]
-        jpg_dir = f'/{home}/ylan/data/DeepGraft/224_128um_annotated/Aachen_Biopsy_Slides/BLOCKS'
+        jpg_dir = f'/{home}/ylan/data/DeepGraft/224_256uM_annotated/Aachen_Biopsy_Slides/BLOCKS'
 
         coords = batch_names.squeeze()
         data = []
@@ -477,23 +539,37 @@ def main(cfg):
     # cfg.Data.label_file = '/home/ylan/DeepGraft/training_tables/dg_limit_20_split_PAS_HE_Jones_norm_rest.json'
     # cfg.Data.patient_slide = '/homeStor1/ylan/DeepGraft/training_tables/cohort_stain_dict.json'
     # cfg.Data.data_dir = '/homeStor1/ylan/data/DeepGraft/224_128um_v2/'
+    train_classic = False
+    if cfg.Model.name in ['inception', 'resnet18', 'resnet50', 'vit']:
+        train_classic = True
+        use_features = False
     if cfg.Model.backbone == 'features':
-        use_features = True
-    else: use_features = False
+        use_features = False
+        cfg.Model.backbone = 'retccl'
+    else: 
+        use_features = False
+
+    print(cfg.Model.backbone)
+    # use_features = False
+
     DataInterface_dict = {
                 'data_root': cfg.Data.data_dir,
                 'label_path': cfg.Data.label_file,
                 'batch_size': cfg.Data.train_dataloader.batch_size,
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
-                'backbone': cfg.Model.backbone,
                 'bag_size': cfg.Data.bag_size,
                 'use_features': use_features,
+                'mixup': cfg.Data.mixup,
+                'aug': cfg.Data.aug,
+                'cache': cfg.Data.cache,
+                'train_classic': train_classic,
+                'model_name': cfg.Model.name,
                 }
 
     dm = MILDataModule(**DataInterface_dict)
     
-
+    # print(cfg.Model.backbone)
     #---->Define Model
     ModelInterface_dict = {'model': cfg.Model,
                             'loss': cfg.Loss,
@@ -503,6 +579,7 @@ def main(cfg):
                             'backbone': cfg.Model.backbone,
                             'task': cfg.task,
                             }
+    # print(ModelInterface_dict)
     # model = ModelInterface(**ModelInterface_dict)
     model = custom_test_module(**ModelInterface_dict)
     # model._fc1 = nn.Sequential(nn.Linear(512, 512), nn.GELU())
@@ -551,8 +628,8 @@ def main(cfg):
     for path in model_paths:
         # with open(f'{log_path}/test_metrics.txt', 'w') as f:
         #     f.write(str(path) + '\n')
-        print(path)
         new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
+        new_model.init_backbone()
         new_model.save_path = Path(cfg.log_path) / 'visualization'
         trainer.test(model=new_model, datamodule=dm)
     
@@ -616,10 +693,12 @@ if __name__ == '__main__':
     from models import TransMIL
     from datasets.zarr_feature_dataloader_simple import ZarrFeatureBagLoader
     from datasets.feature_dataloader import FeatureBagLoader
+    from datasets.jpg_dataloader import JPGMILDataloader
     from torch.utils.data import random_split, DataLoader
     import time
     from tqdm import tqdm
     import torchmetrics
+    import models.ResNet as ResNet
 
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     print(device)
@@ -642,7 +721,19 @@ if __name__ == '__main__':
     n_classes = hyper_parameters['model']['n_classes']
 
     # model = TransMIL()
-    model = TransMIL(n_classes).to(device)
+
+    model_ft = ResNet.resnet50(num_classes=128, mlp=False, two_branch=False, normlinear=True)
+    home = Path.cwd().parts[1]
+    # pre_model = 
+    # self.model_ft.fc = nn.Identity()
+    # self.model_ft.load_from_checkpoint(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth', strict=False)
+    model_ft.load_state_dict(torch.load(f'/{home}/ylan/workspace/TransMIL-DeepGraft/code/models/ckpt/retccl_best_ckpt.pth'), strict=False)
+    for param in model_ft.parameters():
+        param.requires_grad = False
+    model_ft.fc = torch.nn.Identity()
+    model_ft.to(device)
+
+    model = TransMIL(n_classes=n_classes, in_features=2048).to(device)
     model_weights = checkpoint['state_dict']
 
     for key in list(model_weights):
@@ -667,9 +758,9 @@ if __name__ == '__main__':
     model.eval()
 
     home = Path.cwd().parts[1]
-    data_root = f'/{home}/ylan/data/DeepGraft/224_128uM_annotated'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_256uM_annotated'
     label_path = f'/{home}/ylan/DeepGraft/training_tables/dg_split_PAS_HE_Jones_norm_rest.json'
-    dataset = FeatureBagLoader(data_root, label_path=label_path, mode='test', cache=False, n_classes=n_classes)
+    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='test', cache=False, n_classes=n_classes, model_name = 'TransMIL')
 
     dl = DataLoader(dataset, batch_size=1, num_workers=8)
 
@@ -693,7 +784,15 @@ if __name__ == '__main__':
         # print(bag.shape)
         bag = bag.unsqueeze(0)
         with torch.cuda.amp.autocast():
-            logits = model(bag)
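+            # Extract tile features with the frozen backbone, then reshape to (batch, bag, features) before passing them to TransMIL.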
+            batch_size = bag.shape[0]
+            bag_size = bag.shape[1]
+            bag = bag.view(batch_size*bag_size, bag.shape[2], bag.shape[3], bag.shape[4])
+            feats = model_ft(bag).unsqueeze(0)
+            # print(feats.shape)
+            # print(x.shape)
+            # if feats.dim() == 3:
+            feats = feats.view(batch_size, bag_size, -1)
+            logits = model(feats)
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim = 1)
 
diff --git a/code/train.py b/code/train.py
index 53ab165add41af536e18863583798d424df59ec2..984a609e0d4b88dd76232fc3b3fdbbc4aea22a09 100644
--- a/code/train.py
+++ b/code/train.py
@@ -8,6 +8,7 @@ from sklearn.model_selection import KFold
 from datasets.data_interface import MILDataModule, CrossVal_MILDataModule
 # from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
 from models.model_interface import ModelInterface
+from models.model_interface_classic import ModelInterface_Classic
 from models.model_interface_dtfd import ModelInterface_DTFD
 import models.vision_transformer as vits
 from utils.utils import *
@@ -70,9 +71,9 @@ def make_parse():
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
     parser.add_argument('--bag_size', default = 1024, type=int)
+    # parser.add_argument('--batch_size', default = 1, type=int)
     parser.add_argument('--resume_training', action='store_true')
     parser.add_argument('--label_file', type=str)
-    # parser.add_argument('--ckpt_path', default = , type=str)
     
 
     args = parser.parse_args()
@@ -103,6 +104,13 @@ def main(cfg):
     #             'dataset_cfg': cfg.Data,}
     # dm = DataInterface(**DataInterface_dict)
     home = Path.cwd().parts[1]
+
+    train_classic = False
+    if cfg.Model.name in ['inception', 'resnet18', 'resnet50', 'vit']:
+        train_classic = True
+        use_features = False
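+    # Classic per-tile models (CNN/ViT baselines) are trained through ModelInterface_Classic below instead of the MIL interface.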
+
+
     if cfg.Model.backbone == 'features':
         use_features = True
     else: use_features = False
@@ -116,6 +124,9 @@ def main(cfg):
                 'use_features': use_features,
                 'mixup': cfg.Data.mixup,
                 'aug': cfg.Data.aug,
+                'cache': cfg.Data.cache,
+                'train_classic': train_classic,
+                'model_name': cfg.Model.name,
                 }
 
     if cfg.Data.cross_val:
@@ -124,6 +135,7 @@ def main(cfg):
     
 
     #---->Define Model
+    
     ModelInterface_dict = {'model': cfg.Model,
                             'loss': cfg.Loss,
                             'optimizer': cfg.Optimizer,
@@ -131,9 +143,14 @@ def main(cfg):
                             'log': cfg.log_path,
                             'backbone': cfg.Model.backbone,
                             'task': cfg.task,
+                            'in_features': cfg.Model.in_features,
+                            'out_features': cfg.Model.out_features,
                             }
-    if cfg.Model.name == 'DTFDMIL':
-        model = ModelInterface_DTFD(**ModelInterface_dict)
+
+    if train_classic:
+        model = ModelInterface_Classic(**ModelInterface_dict)
+    # elif cfg.Model.name == 'DTFDMIL':
+    #     model = ModelInterface_DTFD(**ModelInterface_dict)
     else:
         model = ModelInterface(**ModelInterface_dict)
     
@@ -169,7 +186,7 @@ def main(cfg):
             logger=cfg.load_loggers,
             callbacks=cfg.callbacks,
             max_epochs= cfg.General.epochs,
-            min_epochs = 100,
+            min_epochs = 150,
 
             # gpus=cfg.General.gpus,
             accelerator='gpu',
@@ -179,10 +196,12 @@ def main(cfg):
             precision=cfg.General.precision,  
             accumulate_grad_batches=cfg.General.grad_acc,
             gradient_clip_val=0.0,
+            log_every_n_steps=10,
             # fast_dev_run = True,
             # limit_train_batches=1,
             
             # deterministic=True,
+            # num_sanity_val_steps=0,
             check_val_every_n_epoch=1,
         )
     # print(cfg.log_path)
@@ -192,21 +211,16 @@ def main(cfg):
 
     # home = Path.cwd()[0]
 
-    copy_path = Path(trainer.loggers[0].log_dir) / 'code'
-    copy_path.mkdir(parents=True, exist_ok=True)
-    copy_origin = '/' / Path('/'.join(cfg.log_path.parts[1:5])) / 'code'
-    # print(copy_path)
-    # print(copy_origin)
-    shutil.copytree(copy_origin, copy_path, dirs_exist_ok=True)
+    if cfg.General.server == 'train':
 
-    
-    # print(trainer.loggers[0].log_dir)
+        copy_path = Path(trainer.loggers[0].log_dir) / 'code'
+        copy_path.mkdir(parents=True, exist_ok=True)
+        copy_origin = '/' / Path('/'.join(cfg.log_path.parts[1:5])) / 'code'
+        shutil.copytree(copy_origin, copy_path, dirs_exist_ok=True)
 
     #---->train or test
     if cfg.resume_training:
         last_ckpt = log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' / 'last.ckpt'
-        # model = model.load_from_checkpoint(last_ckpt)
-        # trainer.fit(model, dm) #, datamodule = dm
         trainer.fit(model = model, ckpt_path=last_ckpt) #, datamodule = dm
 
     if cfg.General.server == 'train':
@@ -222,16 +236,9 @@ def main(cfg):
     else:
         log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}'/'checkpoints' 
 
-        print(log_path)
-        test_path = Path(log_path) / 'test'
-        # for n in range(cfg.Model.n_classes):
-        #     n_output_path = test_path / str(n)
-        #     n_output_path.mkdir(parents=True, exist_ok=True)
-        # print(cfg.log_path)
         model_paths = list(log_path.glob('*.ckpt'))
-        # print(model_paths)
-        # print(cfg.epoch)
-        # model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
+
+
         if cfg.epoch == 'last':
             model_paths = [str(model_path) for model_path in model_paths if f'last' in str(model_path)]
         elif int(cfg.epoch) < 10:
@@ -242,9 +249,9 @@ def main(cfg):
         # model_paths = [f'{log_path}/epoch=279-val_loss=0.4009.ckpt']
 
         for path in model_paths:
-            # print(path)
-            new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
-            trainer.test(model=new_model, datamodule=dm)
+            print(path)
+            model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
+            trainer.test(model=model, datamodule=dm)
 
 
 def check_home(cfg):
diff --git a/code/utils/__pycache__/utils.cpython-39.pyc b/code/utils/__pycache__/utils.cpython-39.pyc
index df4436cad807f2a148eddd27875e344eb4dcbe14..7323a9ff0b5a105e42a911fa4478357cbd7b97a1 100644
Binary files a/code/utils/__pycache__/utils.cpython-39.pyc and b/code/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/code/utils/utils.py b/code/utils/utils.py
old mode 100755
new mode 100644
index 9a756d58876ee847677f414a99318cf426fcf75d..f7781e88e27d288b4a17ed38036aa8fb6f63e4bf
--- a/code/utils/utils.py
+++ b/code/utils/utils.py
@@ -39,7 +39,7 @@ def load_loggers(cfg):
     
     
     #---->TensorBoard
-    if cfg.stage != 'test':
+    if cfg.General.server != 'test':
         
         tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
                                                   # version = f'fold{cfg.Data.fold}'
@@ -51,13 +51,16 @@ def load_loggers(cfg):
                                         ) # version = f'fold{cfg.Data.fold}', 
         # print(csv_logger.version)
     else:  
-        cfg.log_path = Path(cfg.log_path) / f'test'
+        cfg.log_path = Path(cfg.log_path)
+        print('cfg.log_path: ', cfg.log_path)
+
         tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
-                                                version = f'test',
+                                                version = cfg.version,
+                                                sub_dir = f'test_e{cfg.epoch}',
                                                 log_graph = True, default_hp_metric = False)
         #---->CSV
         csv_logger = pl_loggers.CSVLogger(cfg.log_path,
-                                        version = f'test', )
+                                        version = cfg.version, )
                               
     
     print(f'---->Log dir: {cfg.log_path}')
@@ -79,11 +82,11 @@ def load_callbacks(cfg, save_path):
     output_path.mkdir(exist_ok=True, parents=True)
 
     early_stop_callback = EarlyStopping(
-        monitor='val_auc',
+        monitor='val_loss',
         min_delta=0.00,
         patience=cfg.General.patience,
         verbose=True,
-        mode='max'
+        mode='min'
     )
 
     Mycallbacks.append(early_stop_callback)
@@ -105,7 +108,7 @@ def load_callbacks(cfg, save_path):
         # save_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.resume_version}' / last.ckpt
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
                                          dirpath = str(output_path),
-                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc: .4f}-{val_patient_auc}',
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
                                          save_top_k = 2,
@@ -113,7 +116,7 @@ def load_callbacks(cfg, save_path):
                                          save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
                                          dirpath = str(output_path),
-                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc}',
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
                                          save_top_k = 2,
@@ -121,7 +124,7 @@ def load_callbacks(cfg, save_path):
                                          save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_patient_auc',
                                          dirpath = str(output_path),
-                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc}',
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}-{val_patient_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
                                          save_top_k = 2,
diff --git a/project_plan.md b/project_plan.md
index b0e8379c0f5872e1e70d9095c347ab37d9fc19da..0fde7bc92b10e2e6060c99c40ae3f9b7ce2d4d50 100644
--- a/project_plan.md
+++ b/project_plan.md
@@ -7,11 +7,11 @@ With this project, we aim to esatablish a benchmark for weakly supervised deep l
 
 #### Original Lancet Set:
 
-    * Training:
-        * AMS: 1130 Biopsies (3390 WSI)
-        * Utrecht: 717 Biopsies (2151WSI)
-    * Testing:
-        * Aachen: 101 Biopsies (303 WSI)
+* Training:
+    * AMS: 1130 Biopsies (3390 WSI)
+    * Utrecht: 717 Biopsies (2151 WSI)
+* Testing:
+    * Aachen: 101 Biopsies (303 WSI)
 
 
 #### Extended:
@@ -23,80 +23,80 @@ With this project, we aim to esatablish a benchmark for weakly supervised deep l
 
 ## Models:
 
-    For our Benchmark, we chose the following models: 
+For our Benchmark, we chose the following models: 
 
-    - AttentionMIL
-    - Resnet18/50
-    - ViT
-    - CLAM
-    - TransMIL
-    - Monai MIL (optional)
+- AttentionMIL
+- Resnet18/50
+- ViT
+- CLAM
+- TransMIL
+- Monai MIL (optional)
 
-    Resnet18 and Resnet50 are basic CNNs that can be applied for a variety of tasks. Although domain or task specific architectures mostly outperform them, they remain a good baseline for comparison. 
+Resnet18 and Resnet50 are basic CNNs that can be applied to a variety of tasks. Although domain- or task-specific architectures mostly outperform them, they remain a good baseline for comparison. 
 
-    The vision transformer is the first transformer based model that was adapted to computer vision tasks. Benchmarking on ViT can provide more insight on the performance of generic transformer based models on multiple instance learning. 
+The vision transformer (ViT) was the first transformer-based model adapted to computer vision tasks. Benchmarking ViT provides more insight into how generic transformer-based models perform on multiple instance learning. 
 
-    The AttentionMIL was the first simple, yet relatively successful deep MIL model and should be used as a baseline for benchmarking MIL methods. 
+AttentionMIL was the first simple yet relatively successful deep MIL model and should be used as a baseline for benchmarking MIL methods. 
 
-    CLAM is a recent model proposed by Mahmood lab which was explicitely trained for histopathological whole slide images and should be used as a baseline for benchmarking MIL methods in histopathology. 
+CLAM is a recent model proposed by the Mahmood lab that was designed explicitly for histopathological whole slide images and should be used as a baseline for benchmarking MIL methods in histopathology. 
 
-    TransMIL is another model proposed by Shao et al, which achieved SOTA on histopathological WSI classification tasks using MIL. It was benchmarked on TCGA and compared to CLAM and AttMIL. It utilizes the self-attention module from transformer models.
+TransMIL is another model, proposed by Shao et al., which achieved SOTA on histopathological WSI classification tasks using MIL. It was benchmarked on TCGA and compared to CLAM and AttMIL. It utilizes the self-attention module from transformer models.
 
-    Monai MIL (not official name) is a MIL architecture proposed by Myronenk et al (Nvidia). It applies the self-attention mechanism as well. It is included because it shows promising results and it's included in MONAI. 
+Monai MIL (not its official name) is a MIL architecture proposed by Myronenko et al. (NVIDIA). It also applies the self-attention mechanism. It is included because it shows promising results and is available in MONAI. 
 
 ## Tasks:
 
-    The Original tasks mimic the ones published in the original DeepGraft Lancet paper. 
-    Before we go for more challenging tasks (future tasks), we want to establish that our models outperform the simpler approach from the previous paper and that going for MIL in this setting is indeed profitable. 
+The Original tasks mimic the ones published in the original DeepGraft Lancet paper. 
+Before we go for more challenging tasks (future tasks), we want to establish that our models outperform the simpler approach from the previous paper and that adopting MIL in this setting is indeed worthwhile. 
 
-    All available classes: 
-        * Normal
-        * TCMR
-        * ABMR
-        * Mixed
-        * Viral
-        * Other
+All available classes: 
+* Normal
+* TCMR
+* ABMR
+* Mixed
+* Viral
+* Other
 
 #### Original:
 
-    The explicit classes are simplified/grouped together such as this: 
-    Diseased = all classes other than Normal 
-    Rejection = TCMR, ABMR, Mixed 
+The explicit classes are simplified/grouped as follows: 
+* Diseased = all classes other than Normal
+* Rejection = TCMR, ABMR, Mixed
 
-    - (1) Normal vs Diseased (all other classes)
-    - (2) Rejection vs (Viral + Others)
-    - (3) Normal vs Rejection vs (Viral + Others)
+- (1) Normal vs Diseased (all other classes)
+- (2) Rejection vs (Viral + Others)
+- (3) Normal vs Rejection vs (Viral + Others)
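+
+A minimal sketch of how these groupings could be encoded (class names and task numbering are illustrative; the actual training tables use their own JSON format):
+
+```python
+# Hypothetical label grouping for the Original tasks (1)-(3).
+REJECTION = {'TCMR', 'ABMR', 'Mixed'}
+
+def task_label(cls: str, task: int) -> int:
+    if task == 1:  # Normal vs Diseased
+        return 0 if cls == 'Normal' else 1
+    if task == 2:  # Rejection vs (Viral + Others); Normal biopsies are excluded here
+        return 0 if cls in REJECTION else 1
+    if task == 3:  # Normal vs Rejection vs (Viral + Others)
+        return 0 if cls == 'Normal' else (1 if cls in REJECTION else 2)
+    raise ValueError(f'unknown task {task}')
+```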
 
 #### Future:
 
-    After validating Original tasks, the next step is to challenge the models by attempting more complicated tasks. 
-    These experiments may vary depending on the results from previous experiments
+After validating the Original tasks, the next step is to challenge the models with more complicated tasks. 
+These experiments may vary depending on the results of the previous experiments.
 
-    - (4) Normal vs TCMR vs Mixed vs ABMR vs Viral vs Others
-    - (5) TCMR vs Mixed vs ABMR
+- (4) Normal vs TCMR vs Mixed vs ABMR vs Viral vs Others
+- (5) TCMR vs Mixed vs ABMR
 
 ## Plan:
 
-    1. Train models for current tasks on AMS+Utrecht -> Validate on Aachen
+1. Train models for current tasks on AMS+Utrecht -> Validate on Aachen
 
-    2. Visualization, AUC Curves
+2. Visualization, AUC Curves
 
-    3. Train best model on extended training set (AMS+Utrecht+Leuven) (Tasks 1,2,3) -> Validate on Aachen_extended
-        - Investigate if a larger training cohort increases performance
-    4. Train best model on extended dataset on future tasks (Task 4, 5)
+3. Train best model on extended training set (AMS+Utrecht+Leuven) (Tasks 1,2,3) -> Validate on Aachen_extended
+    - Investigate if a larger training cohort increases performance
+4. Train best model on extended dataset on future tasks (Task 4, 5)
 
 
-    Notes: 
-        * Resnet18, ViT and CLAM are all trained on HIA (Training Framework from Kather / Narmin)
+Notes: 
+* Resnet18, ViT and CLAM are all trained on HIA (Training Framework from Kather / Narmin)
     
 
 ## Status: 
 
-        - Resnet18: Trained on all tasks via HIA  
-        - Vit: Trained on all tasks via HIA 
-        - CLAM: Trained on (1) via HIA 
-        - TransMIL: Trained, but overfitting
-            - Check if the problems are not on model side by evaluating on RCC data. 
-            - (mixing in 10 slides from Aachen increases auc performance from 0.7 to 0.89)
-        - AttentionMIL: WIP
-        - Monai MIL: WIP
+- Resnet18: Trained on all tasks via HIA  
+- ViT: Trained on all tasks via HIA 
+- CLAM: Trained on (1) via HIA 
+- TransMIL: Trained, but overfitting
+    - Check that the problem is not on the model side by evaluating on RCC data. 
+    - (mixing in 10 slides from Aachen increases AUC from 0.7 to 0.89)
+- AttentionMIL: WIP
+- Monai MIL: WIP