diff --git a/.gitignore b/.gitignore
index 4cf8dd15619e7c11d325ae0eb80bba874a99f06d..c9e4fdabe0660f6872b93a3d8a59aa1750a671e0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
-logs/*
\ No newline at end of file
+logs/*
+lightning_logs/*
+test/*
\ No newline at end of file
diff --git a/DeepGraft/AttMIL_resnet18_debug.yaml b/DeepGraft/AttMIL_resnet18_debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..03ebd7e60a5af9d32bf1442455e24fc9528557f8
--- /dev/null
+++ b/DeepGraft/AttMIL_resnet18_debug.yaml
@@ -0,0 +1,51 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [1]
+    epochs: &epoch 1 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 2
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
+    fold: 0
+    nfold: 2
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_no_other.yaml b/DeepGraft/AttMIL_simple_no_other.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae90a80ee152a2bd9d5b04c059fff38b1e77c06f
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_no_other.yaml
@@ -0,0 +1,49 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 200 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 20
+    server: train #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 5
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_no_viral.yaml b/DeepGraft/AttMIL_simple_no_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..37ee07479e014d9277f582196eb5a50f4d74e2de
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 4
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_tcmr_viral.yaml b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c982d3ad1bae365a2497a599a805dadf9874a9c2
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
@@ -0,0 +1,49 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 300 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 20
+    server: train #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml
index 79d8ea88865ebddd6a918bdc4b9435b6ba973ff1..7687a0c971796c6b31d199961343047e612c91dd 100644
--- a/DeepGraft/TransMIL_efficientnet_no_other.yaml
+++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml
@@ -6,10 +6,10 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [0]
-    epochs: &epoch 1000 
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 20
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
index 8780060ebd1475b273b06498800523d7108085b1..98fe3778c9528e38d027b1f86d4d4b18631b0fe8 100644
--- a/DeepGraft/TransMIL_efficientnet_no_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
@@ -5,11 +5,11 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [3]
-    epochs: &epoch 500 
+    gpus: [0, 2]
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 20
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
index f69b5bfadaa3023c9b6113e59681daee73b29fe2..52230329255872d150dbc4d0552d1265dc3abc80 100644
--- a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
@@ -5,7 +5,7 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [3]
+    gpus: [0]
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
@@ -16,10 +16,11 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    data_dir: '/home/ylan/data/DeepGraft/256_256um/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 1
-    nfold: 4
+    nfold: 3
+    cross_val: True
 
     train_dataloader:
         batch_size: 1 
@@ -33,6 +34,7 @@ Model:
     name: TransMIL
     n_classes: 2
     backbone: efficientnet
+    in_features: 512
 
 
 Optimizer:
diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_resnet18_debug.yaml
similarity index 91%
rename from DeepGraft/TransMIL_debug.yaml
rename to DeepGraft/TransMIL_resnet18_debug.yaml
index d83ce0d902228644d4fdd8a3059cd4b135a69fde..29bfa1e34ef4250bedd587bb736a9309099b1c0c 100644
--- a/DeepGraft/TransMIL_debug.yaml
+++ b/DeepGraft/TransMIL_resnet18_debug.yaml
@@ -6,10 +6,10 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [1]
-    epochs: &epoch 200 
+    epochs: &epoch 1 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 2
     server: test #train #test
     log_path: logs/
 
@@ -19,7 +19,8 @@ Data:
     data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
     fold: 0
-    nfold: 4
+    nfold: 2
+    cross_val: True
 
     train_dataloader:
         batch_size: 1 
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
index a756616cd4a128874731b17353d108856f72e9f4..f6e469763b0a6df74a2ce2306ae7e94f53d9770b 100644
--- a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -5,8 +5,8 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [3]
-    epochs: &epoch 500 
+    gpus: [0]
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
     patience: 50
@@ -19,7 +19,8 @@ Data:
     data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
     fold: 1
-    nfold: 4
+    nfold: 3
+    cross_val: True
 
     train_dataloader:
         batch_size: 1 
diff --git a/__pycache__/train_loop.cpython-39.pyc b/__pycache__/train_loop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65c5ca839312787cce6aed19f346ae40a7c04d83
Binary files /dev/null and b/__pycache__/train_loop.cpython-39.pyc differ
diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc
index 4a200bbb7d34328019d9bb9e604b0524127ae863..99e793aa4b84b2a5e9ee1882437346e6e4a33002 100644
Binary files a/datasets/__pycache__/custom_dataloader.cpython-39.pyc and b/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20aefffd1a14afe3c499f646e9eb674810896281
Binary files /dev/null and b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc
index 9550db1509e8d477a1c15534a76c6f87976fbec4..59584d31b16ee5ba7d08220a2174fb5e79e9aaed 100644
Binary files a/datasets/__pycache__/data_interface.cpython-39.pyc and b/datasets/__pycache__/data_interface.cpython-39.pyc differ
diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py
index 02850f5db5e263a9d5cdbbfff923952dd5dbfa52..cf65534d749f1cdc764bbdd8193245043c7853f6 100644
--- a/datasets/custom_dataloader.py
+++ b/datasets/custom_dataloader.py
@@ -41,7 +41,7 @@ class HDF5MILDataloader(data.Dataset):
-        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+        data_cache_size: Number of HDF5 files that can be cached (default=10).
 
     """
-    def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024):
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024):
         super().__init__()
 
         self.data_info = []
@@ -55,7 +55,6 @@ class HDF5MILDataloader(data.Dataset):
         self.label_path = label_path
         self.n_classes = n_classes
         self.bag_size = bag_size
-        self.backbone = backbone
         # self.label_file = label_path
         recursive = True
         
@@ -134,10 +133,6 @@ class HDF5MILDataloader(data.Dataset):
             RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
             transforms.ToTensor()
         ])
-        if self.backbone == 'dino':
-            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
-
-        # self._add_data_infos(load_data)
 
     def __getitem__(self, index):
         # get data
@@ -150,9 +145,6 @@ class HDF5MILDataloader(data.Dataset):
             # print(img.shape)
             for img in batch: # expects numpy 
                 img = img.numpy().astype(np.uint8)
-                if self.backbone == 'dino':
-                    img = self.feature_extractor(images=img, return_tensors='pt')
-                    # img = self.resize_transforms(img)
                 img = seq_img_d.augment_image(img)
                 img = self.val_transforms(img)
                 out_batch.append(img)
@@ -160,23 +152,24 @@ class HDF5MILDataloader(data.Dataset):
         else:
             for img in batch:
                 img = img.numpy().astype(np.uint8)
-                if self.backbone == 'dino':
-                    img = self.feature_extractor(images=img, return_tensors='pt')
-                    img = self.resize_transforms(img)
-                
                 img = self.val_transforms(img)
                 out_batch.append(img)
 
-        if len(out_batch) == 0:
-            # print(name)
-            out_batch = torch.randn(self.bag_size,3,256,256)
-        else: out_batch = torch.stack(out_batch)
+        # if len(out_batch) == 0:
+        #     # print(name)
+        #     out_batch = torch.randn(self.bag_size,3,256,256)
+        # else: 
+        out_batch = torch.stack(out_batch)
         # print(out_batch.shape)
         # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
-
+        # print(out_batch.shape)
+        if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3, 256, 256]):
+            print(name)
+            print(out_batch.shape)
+            out_batch = torch.permute(out_batch, (0, 2, 1, 3)) # heuristically swap dims 1 and 2 for the transposed tiles reported above
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
-        return out_batch, label, name
+        return out_batch, label, name #, name_batch
 
     def __len__(self):
         return len(self.data_info)
@@ -184,7 +177,9 @@ class HDF5MILDataloader(data.Dataset):
     def _add_data_infos(self, file_path, load_data):
         wsi_name = Path(file_path).stem
         if wsi_name in self.slideLabelDict:
+            # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
             label = self.slideLabelDict[wsi_name]
+            # print(wsi_name)
             idx = -1
             self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
 
@@ -195,14 +190,29 @@ class HDF5MILDataloader(data.Dataset):
         """
         with h5py.File(file_path, 'r') as h5_file:
             wsi_batch = []
+            tile_names = []
             for tile in h5_file.keys():
+                
                 img = h5_file[tile][:]
                 img = img.astype(np.uint8)
                 img = torch.from_numpy(img)
                 # img = self.resize_transforms(img)
+                
                 wsi_batch.append(img)
-            wsi_batch = torch.stack(wsi_batch)
-            wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size)
+                tile_names.append(tile)
+
+            if wsi_batch:
+                wsi_batch = torch.stack(wsi_batch)
+            else: 
+                print('Empty Container: ', file_path)
+                wsi_batch = torch.randn(self.bag_size,3,256,256)
+
+            if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]):
+                print(file_path)
+                print(wsi_batch.shape)
+            wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size) # second value is the pre-padding tile count, unused here
             idx = self._add_to_cache(wsi_batch, file_path)
             file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
             self.data_info[file_idx + idx]['cache_idx'] = idx
@@ -461,16 +471,19 @@ class RandomHueSaturationValue(object):
             img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
         return img #, lbl
 
-def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512):
+def to_fixed_size_bag(bag, bag_size: int = 512):
 
     # get up to bag_size elements
     bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
     bag_samples = bag[bag_idxs]
+    # bag_sample_names = [bag_names[i] for i in bag_idxs]
     
     # zero-pad if we don't have enough samples
     zero_padded = torch.cat((bag_samples,
                             torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+    # zero_padded_names = bag_sample_names + ['']*(bag_size - len(bag_sample_names))
     return zero_padded, min(bag_size, len(bag))
+    # return zero_padded, zero_padded_names, min(bag_size, len(bag))
 
 
 class RandomHueSaturationValue(object):
@@ -510,11 +523,11 @@ if __name__ == '__main__':
     train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
     data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
     output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
     os.makedirs(output_path, exist_ok=True)
 
-    n_classes = 2
+    n_classes = 5
 
     dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
     # print(dataset.dataset)
@@ -528,15 +541,19 @@ if __name__ == '__main__':
 
     # print(len(dataset))
     # # x = 0
+    #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
     c = 0
     label_count = [0] *n_classes
     for item in dl: 
-        # if c >=10:
-        #     break
+        if c >=10:
+            break
         bag, label, name = item
-        label_count[np.argmax(label)] += 1
-    print(label_count)
-    print(len(train_ds))
+        print(name)
+        # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
+        
+            # print(bag)
+            # print(label)
+        c += 1
     #     # # print(bag.shape)
     #     # if bag.shape[1] == 1:
     #     #     print(name)
@@ -578,7 +595,7 @@ if __name__ == '__main__':
     #         o_img = Image.fromarray(o_img)
     #         o_img = o_img.convert('RGB')
     #         o_img.save(f'{output_path}/{i}_original.png')
-        # c += 1
+        
     #     break
         # else: break
         # print(data.shape)
diff --git a/datasets/custom_jpg_dataloader.py b/datasets/custom_jpg_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..722b27593ee28ebbbcf42249cd9545bdde6ed85c
--- /dev/null
+++ b/datasets/custom_jpg_dataloader.py
@@ -0,0 +1,459 @@
+'''
+ToDo: remove bag_size
+'''
+
+
+import numpy as np
+from pathlib import Path
+import torch
+from torch.utils import data
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+import torchvision.transforms as transforms
+from PIL import Image
+import cv2
+import json
+import albumentations as A
+from imgaug import augmenters as iaa
+import imgaug as ia
+from torchsampler import ImbalancedDatasetSampler
+
+
+class RangeNormalization(object):
+    def __call__(self, sample):
+        img = sample
+        return (img / 255.0 - 0.5) / 0.5
+
+class JPGMILDataloader(data.Dataset):
+    """Represents an abstract HDF5 dataset. For single H5 container! 
+    
+    Input params:
+        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
+        mode: 'train' or 'test'
+        load_data: If True, loads all the data immediately into RAM. Use this if
+            the dataset is fits into memory. Otherwise, leave this at false and 
+            the data will load lazily.
+        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+
+    """
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.bag_size = bag_size
+        self.empty_slides = []
+        # self.label_file = label_path
+        recursive = True
+        
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem 
+
+                # x_complete_path = Path(self.file_path)/Path(x)
+                for cohort in Path(self.file_path).iterdir():
+                    x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
+                    if x_complete_path.is_dir():
+                        if len(list(x_complete_path.iterdir())) > 50: # skip (near-)empty slides with 50 or fewer tiles
+                            # print(x_complete_path)
+                            self.slideLabelDict[x] = y
+                            self.files.append(x_complete_path)
+                        else: 
+                            self.empty_slides.append(x_complete_path)
+        # print(len(self.empty_slides))
+        # print(self.empty_slides)
+
+
+        for slide_dir in tqdm(self.files):
+            self._add_data_infos(str(slide_dir.resolve()), load_data)
+
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
+        ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+
+        # self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+        #     # A.RandomGamma(),
+        #     # A.HorizontalFlip(),
+        #     # A.VerticalFlip(),
+        #     # A.RandomRotate90(),
+        #     # A.OneOf([
+        #     #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #     #     A.Affine(
+        #     #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+        #     #         rotate=(-45, 45),
+        #     #         shear=(-4, 4),
+        #     #         cval=8,
+        #     #         )
+        #     # ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+
+        ])
+        self.img_transforms = transforms.Compose([    
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandomVerticalFlip(p=1),
+            # histoTransforms.AutoRandomRotation(),
+            transforms.Lambda(lambda a: np.array(a)),
+        ]) 
+        self.hsv_transforms = transforms.Compose([
+            RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
+            transforms.ToTensor()
+        ])
+
+    def __getitem__(self, index):
+        # get data
+        batch, label, name = self.get_data(index)
+        out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
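+        # to_deterministic() freezes the sampled augmentation parameters, so every
+        # tile in this bag receives the identical spatial/color transform.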
+        
+        if self.mode == 'train':
+            # print(img)
+            # print(.shape)
+            for img in batch: # expects numpy 
+                img = img.numpy().astype(np.uint8)
+                # print(img.shape)
+                img = seq_img_d.augment_image(img)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+
+        else:
+            for img in batch:
+                img = img.numpy().astype(np.uint8)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+
+        # if len(out_batch) == 0:
+        #     # print(name)
+        #     out_batch = torch.randn(self.bag_size,3,256,256)
+        # else: 
+        out_batch = torch.stack(out_batch)
+        # print(out_batch.shape)
+        # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+        # print(out_batch.shape)
+        # if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]):
+        #     print(name)
+        #     print(out_batch.shape)
+        # out_batch = torch.permute(out_batch, (0, 2,1,3))
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        # print(out_batch)
+        return out_batch, label, name #, name_batch
+
+    def __len__(self):
+        return len(self.data_info)
+    
+    def _add_data_infos(self, file_path, load_data):
+        wsi_name = Path(file_path).stem
+        if wsi_name in self.slideLabelDict:
+            # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
+            label = self.slideLabelDict[wsi_name]
+            # print(wsi_name)
+            idx = -1
+            self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        wsi_batch = []
+        tile_names = []
+        # print(wsi_batch)
+        # for tile_path in Path(file_path).iterdir():
+        #     print(tile_path)
+        for tile_path in Path(file_path).iterdir():
+            # print(tile_path)
+            img = np.asarray(Image.open(tile_path)).astype(np.uint8)
+            img = torch.from_numpy(img)
+
+            # print(wsi_batch)
+            wsi_batch.append(img)
+            
+            tile_names.append(tile_path.stem)
+                
+        # if wsi_batch:
+        wsi_batch = torch.stack(wsi_batch)
+        if len(wsi_batch.shape) < 4: 
+            wsi_batch = wsi_batch.unsqueeze(0) # unsqueeze is not in-place; the result must be reassigned
+        # else: 
+        #     print('Empty Container: ', file_path)
+        #     self.empty_slides.append(file_path)
+        #     wsi_batch = torch.randn(self.bag_size,256,256,3)
+        # print(wsi_batch.shape)
+        # if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]):
+        #     print(file_path)
+        #     print(wsi_batch.shape)
+        # wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size)
+        idx = self._add_to_cache(wsi_batch, file_path)
+        file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+        self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # remove invalid cache_idx
+            # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+        """
+        if data_path not in self.data_cache:
+            self.data_cache[data_path] = [data]
+        else:
+            self.data_cache[data_path].append(data)
+        return len(self.data_cache[data_path]) - 1
+
+    # def get_data_infos(self, type):
+    #     """Get data infos belonging to a certain type of data.
+    #     """
+    #     data_info_type = [di for di in self.data_info if di['type'] == type]
+    #     return data_info_type
+
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_labels(self, indices):
+
+        return [self.data_info[i]['label'] for i in indices]
+        # return self.slideLabelDict.values()
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+            dataset. This will make sure that the data is loaded in case it is
+            not part of the data cache.
+            i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+        
+        # get new cache_idx assigned by _load_data_info
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name
+
+
+
+class RandomHueSaturationValue(object):
+
+    def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+        
+        self.hue_shift_limit = hue_shift_limit
+        self.sat_shift_limit = sat_shift_limit
+        self.val_shift_limit = val_shift_limit
+        self.p = p
+
+    def __call__(self, sample):
+    
+        img = sample #,lbl
+    
+        if np.random.random() < self.p:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+            h, s, v = cv2.split(img)
+            hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+            hue_shift = np.uint8(hue_shift)
+            h += hue_shift
+            sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+            s = cv2.add(s, sat_shift)
+            val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+            v = cv2.add(v, val_shift)
+            img = cv2.merge((h, s, v))
+            img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+        return img #, lbl
+
+def to_fixed_size_bag(bag, bag_size: int = 512):
+
+    # duplicate bag instances until the bag reaches bag_size (self-padding instead of zero-padding)
+
+    # get up to bag_size elements
+    bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+    bag_samples = bag[bag_idxs]
+    # bag_sample_names = [bag_names[i] for i in bag_idxs]
+    q, r  = divmod(bag_size, bag_samples.shape[0])
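+    # e.g. 300 tiles with bag_size=512: divmod gives q=1, r=212 -> repeat the bag
+    # once (300 tiles) and append its first 212 tiles to reach exactly bag_size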
+    if q > 0:
+        bag_samples = torch.cat([bag_samples]*q, 0)
+    
+    self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+    # zero-pad if we don't have enough samples
+    # zero_padded = torch.cat((bag_samples,
+                            # torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+    return self_padded, min(bag_size, len(bag))
+
+
+
+if __name__ == '__main__':
+    from pathlib import Path
+    import os
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/256_256um'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
+    os.makedirs(output_path, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
+    # print(dataset.dataset)
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    dl = DataLoader(dataset,  None, num_workers=1)
+    print(len(dl))
+    dl = DataLoader(dataset,  None, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+
+    
+    
+    # data = DataLoader(dataset, batch_size=1)
+
+    # print(len(dataset))
+    # # x = 0
+    #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
+    c = 0
+    label_count = [0] *n_classes
+    print(len(dl))
+    for item in dl: 
+        # if c >=10:
+        #     break
+        bag, label, name = item
+        # print(label)
+        label_count[torch.argmax(label)] += 1
+        # print(name)
+        # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
+        
+            # print(bag)
+            # print(label)
+        c += 1
+    print(label_count)
+    #     # # print(bag.shape)
+    #     # if bag.shape[1] == 1:
+    #     #     print(name)
+    #     #     print(bag.shape)
+        # print(bag.shape)
+        
+    #     # out_dir = Path(output_path) / name
+    #     # os.makedirs(out_dir, exist_ok=True)
+
+    #     # # print(item[2])
+    #     # # print(len(item))
+    #     # # print(item[1])
+    #     # # print(data.shape)
+    #     # # data = data.squeeze()
+    #     # bag = item[0]
+    #     bag = bag.squeeze()
+    #     original = original.squeeze()
+    #     for i in range(bag.shape[0]):
+    #         img = bag[i, :, :, :]
+    #         img = img.squeeze()
+            
+    #         img = ((img-img.min())/(img.max() - img.min())) * 255
+    #         print(img)
+    #         # print(img)
+    #         img = img.numpy().astype(np.uint8).transpose(1,2,0)
+
+            
+    #         img = Image.fromarray(img)
+    #         img = img.convert('RGB')
+    #         img.save(f'{output_path}/{i}.png')
+
+
+            
+    #         o_img = original[i,:,:,:]
+    #         o_img = o_img.squeeze()
+    #         print(o_img.shape)
+    #         o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255
+    #         o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0)
+    #         o_img = Image.fromarray(o_img)
+    #         o_img = o_img.convert('RGB')
+    #         o_img.save(f'{output_path}/{i}_original.png')
+        
+    #     break
+        # else: break
+        # print(data.shape)
+        # print(label)
+    # a = [torch.Tensor((3,256,256))]*3
+    # b = torch.stack(a)
+    # print(b)
+    # c = to_fixed_size_bag(b, 512)
+    # print(c)
\ No newline at end of file
diff --git a/datasets/data_interface.py b/datasets/data_interface.py
index 056e6ff09bcf768c91e0f51e44640487ea4fbee1..efa104c4f3a1c0ce12ece5cfc285202c8f8f1825 100644
--- a/datasets/data_interface.py
+++ b/datasets/data_interface.py
@@ -2,15 +2,25 @@ import inspect # inspect parameters and source code of Python classes, modules and functions
 import importlib # In order to dynamically import the library
 from typing import Optional
 import pytorch_lightning as pl
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+
 from torch.utils.data import random_split, DataLoader
+from torch.utils.data.dataset import Dataset, Subset
 from torchvision.datasets import MNIST
 from torchvision import transforms
 from .camel_dataloader import FeatureBagLoader
 from .custom_dataloader import HDF5MILDataloader
+from .custom_jpg_dataloader import JPGMILDataloader
 from pathlib import Path
 from transformers import AutoFeatureExtractor
 from torchsampler import ImbalancedDatasetSampler
 
+from abc import ABC, abstractmethod
+from sklearn.model_selection import KFold
+
+
+
 class DataInterface(pl.LightningDataModule):
 
     def __init__(self, train_batch_size=64, train_num_workers=8, test_batch_size=1, test_num_workers=1,dataset_name=None, **kwargs):
@@ -109,7 +119,7 @@ class DataInterface(pl.LightningDataModule):
 
 class MILDataModule(pl.LightningDataModule):
 
-    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=50, n_classes=2, cache: bool=True, *args, **kwargs):
         super().__init__()
         self.data_root = data_root
         self.label_path = label_path
@@ -124,27 +134,29 @@ class MILDataModule(pl.LightningDataModule):
         self.num_bags_test = 50
         self.seed = 1
 
-        self.backbone = backbone
         self.cache = True
         self.fe_transform = None
+        
 
 
     def setup(self, stage: Optional[str] = None) -> None:
         home = Path.cwd().parts[1]
 
         if stage in (None, 'fit'):
-            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
+            dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
             self.train_data, self.valid_data = random_split(dataset, [a, b])
 
         if stage in (None, 'test'):
-            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+            self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1)
 
         return super().setup(stage=stage)
 
+        
+
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.train_data,  self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        return DataLoader(self.train_data,  batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         #sampler=ImbalancedDatasetSampler(self.train_data)
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
@@ -187,13 +199,92 @@ class DataModule(pl.LightningDataModule):
         if stage in (None, 'test'):
             self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
 
+
         return super().setup(stage=stage)
 
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.train_data,  self.batch_size,  num_workers=self.num_workers, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        return DataLoader(self.train_data,  self.batch_size, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         #sampler=ImbalancedDatasetSampler(self.train_data),
     def val_dataloader(self) -> DataLoader:
-        return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.valid_data, batch_size = self.batch_size)
+    
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_data, batch_size = self.batch_size) #, num_workers=self.num_workers
+
+
+class BaseKFoldDataModule(pl.LightningDataModule, ABC):
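+    """Interface assumed by a custom K-fold fit loop (in the spirit of Lightning's
+    loop-customization example): setup_folds(num_folds) is expected to be called once
+    before fitting, and setup_fold_index(i) once before each fold is run."""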
+    @abstractmethod
+    def setup_folds(self, num_folds: int) -> None:
+        pass
+
+    @abstractmethod
+    def setup_fold_index(self, fold_index: int) -> None:
+        pass
+
+class CrossVal_MILDataModule(BaseKFoldDataModule):
+
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+        super().__init__()
+        self.data_root = data_root
+        self.label_path = label_path
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.image_size = 384
+        self.n_classes = n_classes
+        self.target_number = 9
+        self.mean_bag_length = 10
+        self.var_bag_length = 2
+        self.num_bags_train = 200
+        self.num_bags_test = 50
+        self.seed = 1
+
+        self.backbone = backbone
+        self.cache = True
+        self.fe_transform = None
+
+        # train_dataset: Optional[Dataset] = None
+        # test_dataset: Optional[Dataset] = None
+        # train_fold: Optional[Dataset] = None
+        # val_fold: Optional[Dataset] = None
+        self.train_data : Optional[Dataset] = None
+        self.test_data : Optional[Dataset] = None
+        self.train_fold : Optional[Dataset] = None
+        self.val_fold : Optional[Dataset] = None
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        home = Path.cwd().parts[1]
+
+        # if stage in (None, 'fit'):
+        dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
+        # a = int(len(dataset)* 0.8)
+        # b = int(len(dataset) - a)
+        # self.train_data, self.val_data = random_split(dataset, [a, b])
+        self.train_data = dataset
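+        # no random_split here: validation folds are carved out of the full training
+        # set later via setup_folds/setup_fold_index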
+
+        # if stage in (None, 'test'):,
+        self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes)
+
+        # return super().setup(stage=stage)
+
+    def setup_folds(self, num_folds: int) -> None:
+        self.num_folds = num_folds
+        self.splits = [split for split in KFold(num_folds).split(range(len(self.train_data)))]
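+        # sklearn's KFold yields (train_idx, val_idx) index pairs; with num_folds=k,
+        # every slide lands in exactly one validation fold and k-1 training folds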
+
+    def setup_fold_index(self, fold_index: int) -> None:
+        train_indices, val_indices = self.splits[fold_index]
+        self.train_fold = Subset(self.train_data, train_indices)
+        self.val_fold = Subset(self.train_data, val_indices)
+
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_fold,  self.batch_size, sampler=ImbalancedDatasetSampler(self.train_fold), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        # return DataLoader(self.train_fold,  self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        #sampler=ImbalancedDatasetSampler(self.train_data)
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.val_fold, batch_size = self.batch_size, num_workers=self.num_workers)
     
     def test_dataloader(self) -> DataLoader:
-        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
\ No newline at end of file
+        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
+
+
+
diff --git a/models/AttMIL.py b/models/AttMIL.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1e20eb3e5d465eab0008c5121b21f1fd9394605
--- /dev/null
+++ b/models/AttMIL.py
@@ -0,0 +1,79 @@
+import os
+import logging
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+import pytorch_lightning as pl
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+
+
+class AttMIL(nn.Module): #gated attention
+    def __init__(self, n_classes, features=512):
+        super(AttMIL, self).__init__()
+        self.L = features
+        self.D = 128
+        self.K = 1
+        self.n_classes = n_classes
+
+        # resnet50 = models.resnet50(pretrained=True)    
+        # modules = list(resnet50.children())[:-3]
+
+        # self.resnet_extractor = nn.Sequential(
+        #     *modules,
+        #     nn.AdaptiveAvgPool2d(1),
+        #     View((-1, 1024)),
+        #     nn.Linear(1024, self.L)
+        # )
+
+        # self.feature_extractor1 = nn.Sequential(
+        #     nn.Conv2d(3, 20, kernel_size=5),
+        #     nn.ReLU(),
+        #     nn.MaxPool2d(2, stride=2),
+        #     nn.Conv2d(20, 50, kernel_size=5),
+        #     nn.ReLU(),
+        #     nn.MaxPool2d(2, stride=2),
+
+        #     # View((-1, 50 * 4 * 4)),
+        #     # nn.Linear(50 * 4 * 4, self.L),
+        #     # nn.ReLU(),
+        # )
+
+        # self.feature_extractor_part2 = nn.Sequential(
+        #     nn.Linear(50 * 4 * 4, self.L),
+        #     nn.ReLU(),
+        # )
+
+        self.attention_V = nn.Sequential(
+            nn.Linear(self.L, self.D),
+            nn.Tanh()
+        )
+
+        self.attention_U = nn.Sequential(
+            nn.Linear(self.L, self.D),
+            nn.Sigmoid()
+        )
+
+        self.attention_weights = nn.Linear(self.D, self.K)
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.L * self.K, self.n_classes),
+        )    
+
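+    # Gated attention pooling (Ilse et al., 2018): a = softmax(w^T (tanh(V·h) * sigm(U·h))),
+    # giving the slide embedding M = A @ H, an attention-weighted sum of tile features.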
+    def forward(self, x):
+        # H = kwargs['data'].float().squeeze(0)
+        H = x.float().squeeze(0)
+        A_V = self.attention_V(H)  # NxD
+        A_U = self.attention_U(H)  # NxD
+        A = self.attention_weights(A_V * A_U) # element wise multiplication # NxK
+        out_A = A
+        A = torch.transpose(A, 1, 0)  # KxN
+        A = F.softmax(A, dim=1)  # softmax over N
+        M = torch.mm(A, H)  # KxL
+        logits = self.classifier(M)
+       
+        return logits
\ No newline at end of file
diff --git a/models/TransMIL.py b/models/TransMIL.py
index 69089de0bd934ea38b2854f513295b21cbb73a03..ca2a1fbe8240fd1983b246cbbfbdad187c48b214 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -44,23 +44,23 @@ class PPEG(nn.Module):
 
 
 class TransMIL(nn.Module):
-    def __init__(self, n_classes):
+    def __init__(self, n_classes, in_features, out_features=384):
         super(TransMIL, self).__init__()
-        self.pos_layer = PPEG(dim=512)
-        self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
+        self.pos_layer = PPEG(dim=out_features)
+        self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
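+        # _fc1 projects backbone embeddings (in_features, 512 in the yaml configs)
+        # down to the transformer width out_features; GELU replaces the previous ReLU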
         # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
-        self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
         self.n_classes = n_classes
-        self.layer1 = TransLayer(dim=512)
-        self.layer2 = TransLayer(dim=512)
-        self.norm = nn.LayerNorm(512)
-        self._fc2 = nn.Linear(512, self.n_classes)
+        self.layer1 = TransLayer(dim=out_features)
+        self.layer2 = TransLayer(dim=out_features)
+        self.norm = nn.LayerNorm(out_features)
+        self._fc2 = nn.Linear(out_features, self.n_classes)
 
 
-    def forward(self, **kwargs): #, **kwargs
+    def forward(self, x): #, **kwargs
 
-        h = kwargs['data'].float() #[B, n, 1024]
-        # h = self._fc1(h) #[B, n, 512]
+        h = x.float() #[B, n, in_features]
+        h = self._fc1(h) #[B, n, out_features]
         
         #---->pad
         H = h.shape[1]
@@ -83,22 +83,22 @@ class TransMIL(nn.Module):
         h = self.layer2(h) #[B, N, 512]
 
         #---->cls_token
+        # print(h.shape) #[1, 1025, 512]: 1025 = cls_token + 1024 padded tiles
+
+        # tokens = h
         h = self.norm(h)[:,0]
 
         #---->predict
         logits = self._fc2(h) #[B, n_classes]
-        Y_hat = torch.argmax(logits, dim=1)
-        Y_prob = F.softmax(logits, dim = 1)
-        results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
-        return results_dict
+        return logits
 
 if __name__ == "__main__":
     data = torch.randn((1, 6000, 512)).cuda()
-    model = TransMIL(n_classes=2).cuda()
+    model = TransMIL(n_classes=2, in_features=512).cuda()
     print(model.eval())
-    results_dict = model(data = data)
-    print(results_dict)
+    logits = model(data)
+    print(logits)
-    logits = results_dict['logits']
-    Y_prob = results_dict['Y_prob']
-    Y_hat = results_dict['Y_hat']
+    # the model now returns bare logits; Y_prob / Y_hat are derived in ModelInterface.step
     # print(F.sigmoid(logits))
diff --git a/models/__pycache__/AttMIL.cpython-39.pyc b/models/__pycache__/AttMIL.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ee3af4e0d05559e5bc7a4d8ed62481a3254c9f1
Binary files /dev/null and b/models/__pycache__/AttMIL.cpython-39.pyc differ
diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
index cb0eb6fd5085feda9604ec0f24420d85eb1d939d..a329e23346d376b150c7ba792b8cd3593a128d95 100644
Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and b/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
index 0bf337675c76de4097e7f7723b99de0120f9594c..e9d22d7ddccbb2a3d2ab59b98e066f9ac5d42285 100644
Binary files a/models/__pycache__/model_interface.cpython-39.pyc and b/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/models/model_interface.py b/models/model_interface.py
index 60b5cc73d6bec977216cdde6bb8bcee3fd267034..b3a561eba25656b93f4e4663c88e796022a556d0 100755
--- a/models/model_interface.py
+++ b/models/model_interface.py
@@ -7,6 +7,8 @@ import pandas as pd
 import seaborn as sns
 from pathlib import Path
 from matplotlib import pyplot as plt
+import cv2
+from PIL import Image
 
 #---->
 from MyOptimizer import create_optimizer
@@ -20,7 +22,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics
+from torchmetrics.functional import stat_scores
 from torch import optim as optim
+# from sklearn.metrics import roc_curve, auc, roc_curve_score
+
 
 #---->
 import pytorch_lightning as pl
@@ -29,6 +34,10 @@ from torchvision import models
 from torchvision.models import resnet
 from transformers import AutoFeatureExtractor, ViTModel
 
+from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
 from captum.attr import LayerGradCam
 
 class ModelInterface(pl.LightningModule):
@@ -41,11 +50,20 @@ class ModelInterface(pl.LightningModule):
         self.loss = create_loss(loss)
         # self.asl = AsymmetricLossSingleLabel()
         # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
-        
         # self.loss = 
+        # print(self.model)
+        
+        
+        # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
         self.optimizer = optimizer
         self.n_classes = model.n_classes
-        self.log_path = kargs['log']
+        print(self.n_classes)
+        self.save_path = kargs['log']
+        if Path(self.save_path).parts[3] == 'tcmr':
+            temp = list(Path(self.save_path).parts)
+            # print(temp)
+            temp[3] = 'tcmr_viral'
+            self.save_path = '/'.join(temp)
 
         #---->acc
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
@@ -53,6 +71,7 @@ class ModelInterface(pl.LightningModule):
         #---->Metrics
         if self.n_classes > 2: 
             self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
                                                                            average='micro'),
                                                      torchmetrics.CohenKappa(num_classes = self.n_classes),
@@ -67,6 +86,7 @@ class ModelInterface(pl.LightningModule):
                                                                             
         else : 
             self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
                                                                            average = 'micro'),
                                                      torchmetrics.CohenKappa(num_classes = 2),
@@ -76,6 +96,8 @@ class ModelInterface(pl.LightningModule):
                                                                          num_classes = 2),
                                                      torchmetrics.Precision(average = 'macro',
                                                                             num_classes = 2)])
+        self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes)
+        # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10)
         self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)                                                                    
         self.valid_metrics = metrics.clone(prefix = 'val_')
         self.test_metrics = metrics.clone(prefix = 'test_')
@@ -146,28 +168,40 @@ class ModelInterface(pl.LightningModule):
                 nn.Linear(1024, self.out_features),
                 nn.ReLU(),
             )
+        # print(self.model_ft[0].features[-1])
+        # print(self.model_ft)
+        if model.name == 'TransMIL':
+            target_layers = [self.model.layer2.norm] # 32x32
+            # target_layers = [self.model_ft[0].features[-1]] # 32x32
+            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
+            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+        else:
+            target_layers = [self.model.attention_weights]
+            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
+
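+    # Full pipeline: tile bag -> CNN feature extractor (model_ft) -> per-tile
+    # embeddings [1, N, in_features] -> MIL aggregator (model) -> slide-level logits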
+    def forward(self, x):
+        
+        feats = self.model_ft(x).unsqueeze(0)
+        return self.model(feats)
+
+    def step(self, input):
+
+        input = input.squeeze(0).float()
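+        # bags arrive as [1, N, 3, H, W] (batch_size is 1 in the configs); drop the
+        # batch dim so the feature extractor sees [N, 3, H, W]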
+        logits = self(input) 
+
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim=1)
+
+        return logits, Y_prob, Y_hat
 
     def training_step(self, batch, batch_idx):
         #---->inference
         
-        data, label, _ = batch
+
+        input, label, _= batch
         label = label.float()
-        data = data.squeeze(0).float()
-        # print(data)
-        # print(data.shape)
-        if self.backbone == 'dino':
-            features = self.model_ft(**data)
-            features = features.last_hidden_state
-        else:
-            features = self.model_ft(data)
-        features = features.unsqueeze(0)
-        # print(features.shape)
-        # features = features.squeeze()
-        results_dict = self.model(data=features) 
-        # results_dict = self.model(data=data, label=label)
-        logits = results_dict['logits']
-        Y_prob = results_dict['Y_prob']
-        Y_hat = results_dict['Y_hat']
+        
+        logits, Y_prob, Y_hat = self.step(input) 
 
         #---->loss
         loss = self.loss(logits, label)
@@ -183,6 +217,14 @@ class ModelInterface(pl.LightningModule):
             # Y = int(label[0])
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (int(Y_hat) == Y)
+        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True)
+
+        if self.current_epoch % 10 == 0:
+            # log a grid of the bag's input tiles every 10 epochs; squeeze the
+            # leading bag dimension so make_grid sees [N, C, H, W]
+            grid = torchvision.utils.make_grid(input.squeeze(0))
+            self.loggers[0].experiment.add_image('train/input', grid, self.current_epoch)
+
 
         return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
 
@@ -212,18 +254,10 @@ class ModelInterface(pl.LightningModule):
 
     def validation_step(self, batch, batch_idx):
 
-        data, label, _ = batch
-
+        input, label, _ = batch
         label = label.float()
-        data = data.squeeze(0).float()
-        features = self.model_ft(data)
-        features = features.unsqueeze(0)
-
-        results_dict = self.model(data=features)
-        logits = results_dict['logits']
-        Y_prob = results_dict['Y_prob']
-        Y_hat = results_dict['Y_hat']
-
+        
+        logits, Y_prob, Y_hat = self.step(input) 
 
         #---->acc log
         # Y = int(label[0][1])
@@ -237,18 +271,23 @@ class ModelInterface(pl.LightningModule):
 
     def validation_epoch_end(self, val_step_outputs):
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
-        # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
         target = torch.cat([x['label'] for x in val_step_outputs])
         target = torch.argmax(target, dim=1)
         #---->
         # logits = logits.long()
         # target = target.squeeze().long()
         # logits = logits.squeeze(0)
+        # AUROC is undefined when the validation targets contain only one
+        # class, so log a placeholder value of 0.0 in that case
+        if len(target.unique()) != 1:
+            self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+        else:
+            self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True)
+
         self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
-        self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
 
         # print(max_probs.squeeze(0).shape)
         # print(target.shape)
@@ -276,24 +315,61 @@ class ModelInterface(pl.LightningModule):
             random.seed(self.count*50)
 
     def test_step(self, batch, batch_idx):
-
-        data, label, _ = batch
+        # Grad-CAM needs gradients, which Lightning disables during testing
+        torch.set_grad_enabled(True)
+        data, label, name = batch
         label = label.float()
         data = data.squeeze(0).float()
-        features = self.model_ft(data)
-        features = features.unsqueeze(0)
+        logits = self(data).detach() 
+
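+        # labels arrive one-hot encoded; recover the integer class index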
+        Y = torch.argmax(label)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+        
+        #----> Get top-k tiles
+        # run Grad-CAM against the ground-truth class; the CAM scores how
+        # strongly each tile drives that class prediction
+        target = [ClassifierOutputTarget(Y)]
+
+        data_ft = self.model_ft(data).unsqueeze(0).float()
+        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
 
-        results_dict = self.model(data=features, label=label)
-        logits = results_dict['logits']
-        Y_prob = results_dict['Y_prob']
-        Y_hat = results_dict['Y_hat']
+        # average the CAM over the feature dimension to get one saliency
+        # score per tile, then keep the five highest-scoring tiles
+        summed = torch.mean(torch.Tensor(grayscale_cam), dim=2)
+        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
+        topk_data = data[topk_indices].detach()
+
 
         #---->acc log
         Y = torch.argmax(label)
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name, 'topk_data': topk_data}
 
     def test_epoch_end(self, output_results):
         probs = torch.cat([x['Y_prob'] for x in output_results])
@@ -301,7 +377,8 @@ class ModelInterface(pl.LightningModule):
         # target = torch.stack([x['label'] for x in output_results], dim = 0)
         target = torch.cat([x['label'] for x in output_results])
         target = torch.argmax(target, dim=1)
-        
+        patients = [x['name'] for x in output_results]
+        topk_tiles = [x['topk_data'] for x in output_results]
         #---->
         auc = self.AUROC(probs, target.squeeze())
         metrics = self.test_metrics(max_probs.squeeze() , target)
@@ -312,9 +389,41 @@ class ModelInterface(pl.LightningModule):
 
         # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
 
-        # print(max_probs.squeeze(0).shape)
-        # print(target.shape)
-        # self.log_dict(metrics, logger = True)
+        #---->get highest scoring patients for each class
+        test_path = Path(self.save_path) / 'most_predictive'
+        topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
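+        # topk_indices has shape [5, n_classes]: row r, column c is the index
+        # of the sample ranked r-th by predicted probability for class c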
+        for n in range(self.n_classes):
+            print('class: ', n)
+            topk_patients = [patients[i[n]] for i in topk_indices]
+            topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
+            for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
+                print(p, x[n])
+                patient = p[0]
+                outpath = test_path / str(n) / patient 
+                outpath.mkdir(parents=True, exist_ok=True)
+                for i in range(len(t)):
+                    tile = t[i].cpu().numpy().transpose(1, 2, 0)
+                    # min-max normalize to [0, 255] so the tile saves as 8-bit
+                    tile = (tile - tile.min()) / (tile.max() - tile.min()) * 255
+                    tile = tile.astype(np.uint8)
+                    img = Image.fromarray(tile)
+                    img.save(outpath / f'{i}_gradcam.jpg')
+
         for keys, values in metrics.items():
             print(f'{keys} = {values}')
             metrics[keys] = values.cpu().numpy()
@@ -329,16 +438,35 @@ class ModelInterface(pl.LightningModule):
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
+        # TODO: plot per-class ROC curves (fpr/tpr via sklearn.metrics.roc_curve)
+
         self.log_confusion_matrix(probs, target, stage='test')
         #---->
         result = pd.DataFrame([metrics])
-        result.to_csv(self.log_path / 'result.csv')
+        result_path = Path(self.save_path) / 'test_result.csv'
+        # append across runs; write the header only when the file is new
+        result.to_csv(result_path, mode='a', header=not result_path.exists())
+
 
     def configure_optimizers(self):
         # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
         optimizer = create_optimizer(self.optimizer, self.model)
         return optimizer     
 
+    def reshape_transform(self, tensor, h=32, w=32):
+        # drop the leading class token and fold the transformer token
+        # sequence back into a [B, C, h, w] feature map for Grad-CAM
+        result = tensor[:, 1:, :].reshape(tensor.size(0), h, w, tensor.size(2))
+        result = result.transpose(2, 3).transpose(1, 2)
+        return result
 
     def load_model(self):
         name = self.hparams.model.name
@@ -372,18 +500,33 @@ class ModelInterface(pl.LightningModule):
         args1.update(other_args)
         return Model(**args1)
 
+    def log_image(self, tensor, stage, name):
+        # normalize a [C, H, W] tile to uint8 and log it to TensorBoard;
+        # add_image takes HWC arrays when dataformats='HWC' is passed
+        tile = tensor.cpu().numpy().transpose(1, 2, 0)
+        tile = (tile - tile.min()) / (tile.max() - tile.min()) * 255
+        tile = tile.astype(np.uint8)
+        self.loggers[0].experiment.add_image(f'{stage}/{name}', tile, self.current_epoch, dataformats='HWC')
+
+
     def log_confusion_matrix(self, max_probs, target, stage):
         confmat = self.confusion_matrix(max_probs.squeeze(), target)
+        print(confmat)
         df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        plt.figure()
         fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
         # plt.close(fig_)
-        # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
-        self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
 
-        if stage == 'test':
-            plt.savefig(f'{self.log_path}/cm_test')
-        plt.close(fig_)
+        if stage == 'train':
+            # during training, send the confusion matrix to TensorBoard
+            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+        else:
+            # at test time, save the figure to disk next to the other results
+            fig_.savefig(f'{self.save_path}/cm_test.png', dpi=400)
+        plt.close(fig_)  # close so figures don't accumulate across epochs
         # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
 
 class View(nn.Module):
diff --git a/test_visualize.py b/test_visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ef56d3fec1a0c3d926428e95b051c13bdb1da96
--- /dev/null
+++ b/test_visualize.py
@@ -0,0 +1,148 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import glob
+
+from sklearn.model_selection import KFold
+
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
+from models.model_interface import ModelInterface
+import models.vision_transformer as vits
+from utils.utils import *
+
+# pytorch_lightning
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+import torch
+from train_loop import KFoldLoop
+
+#--->Setting parameters
+def make_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--stage', default='train', type=str)
+    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    parser.add_argument('--version', default=0,type=int)
+    parser.add_argument('--epoch', default='0',type=str)
+    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
+    parser.add_argument('--fold', default = 0)
+    parser.add_argument('--bag_size', default = 1024, type=int)
+
+    args = parser.parse_args()
+    return args
+
+#---->main
+def main(cfg):
+
+    torch.set_num_threads(16)
+
+    #---->Initialize seed
+    pl.seed_everything(cfg.General.seed)
+
+    #---->load loggers
+    # cfg.load_loggers = load_loggers(cfg)
+
+    # print(cfg.load_loggers)
+    # save_path = Path(cfg.load_loggers[0].log_dir) 
+
+    #---->load callbacks
+    # cfg.callbacks = load_callbacks(cfg, save_path)
+
+    home = Path.cwd().parts[1]
+    DataInterface_dict = {
+                'data_root': cfg.Data.data_dir,
+                'label_path': cfg.Data.label_file,
+                'batch_size': cfg.Data.train_dataloader.batch_size,
+                'num_workers': cfg.Data.train_dataloader.num_workers,
+                'n_classes': cfg.Model.n_classes,
+                'backbone': cfg.Model.backbone,
+                'bag_size': cfg.Data.bag_size,
+                }
+
+    dm = MILDataModule(**DataInterface_dict)
+    
+
+    #---->Define Model
+    ModelInterface_dict = {'model': cfg.Model,
+                            'loss': cfg.Loss,
+                            'optimizer': cfg.Optimizer,
+                            'data': cfg.Data,
+                            'log': cfg.log_path,
+                            'backbone': cfg.Model.backbone,
+                            }
+    model = ModelInterface(**ModelInterface_dict)
+    
+    #---->Instantiate Trainer
+    trainer = Trainer(
+        num_sanity_val_steps=0, 
+        # logger=cfg.load_loggers,
+        # callbacks=cfg.callbacks,
+        max_epochs= cfg.General.epochs,
+        min_epochs = 200,
+        gpus=cfg.General.gpus,
+        # gpus = [0,2],
+        # strategy='ddp',
+        amp_backend='native',
+        # amp_level=cfg.General.amp_level,  
+        precision=cfg.General.precision,  
+        accumulate_grad_batches=cfg.General.grad_acc,
+        # fast_dev_run = True,
+        
+        # deterministic=True,
+        check_val_every_n_epoch=10,
+    )
+
+    #---->train or test
+    log_path = cfg.log_path
+    # print(log_path)
+    # log_path = Path('lightning_logs/2/checkpoints')
+    model_paths = list(log_path.glob('*.ckpt'))
+
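+    # select which checkpoint(s) to evaluate: the 'last' checkpoint or those
+    # matching the requested epoch number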
+    if cfg.epoch == 'last':
+        model_paths = [str(model_path) for model_path in model_paths if 'last' in str(model_path)]
+    else:
+        model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+
+    if not model_paths: 
+        print('No checkpoints available!')
+    for path in model_paths:
+        # with open(f'{log_path}/test_metrics.txt', 'w') as f:
+        #     f.write(str(path) + '\n')
+        print(path)
+        new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
+        trainer.test(model=new_model, datamodule=dm)
+    
+    # Top 5 scoring patches for patient
+    # GradCam
+
+
+if __name__ == '__main__':
+
+    args = make_parse()
+    cfg = read_yaml(args.config)
+
+    #---->update
+    cfg.config = args.config
+    cfg.General.gpus = [args.gpus]
+    cfg.General.server = args.stage
+    cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
+    cfg.version = args.version
+    cfg.epoch = args.epoch
+
+    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+    log_name =  f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
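+    # config names follow '<model>_<backbone>_<task>.yaml'; drop the suffix and
+    # join everything after the second underscore to recover the task name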
+    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' / 'checkpoints'
+
+    #---->main
+    main(cfg)
+ 
\ No newline at end of file
diff --git a/train.py b/train.py
index 036d5ed183d4c8ee047ff3413bfb3d4730154b19..5e30394352f771017d43e31cc8585ed4ef9aeb07 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ import glob
 
 from sklearn.model_selection import KFold
 
-from datasets.data_interface import DataInterface, MILDataModule
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
 from models.model_interface import ModelInterface
 import models.vision_transformer as vits
 from utils.utils import *
@@ -13,18 +13,23 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
-from pytorch_lightning.loops import KFoldLoop
 import torch
+from train_loop import KFoldLoop
 
 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    parser.add_argument('--version', default=2,type=int)
     parser.add_argument('--gpus', default = 2, type=int)
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
     parser.add_argument('--bag_size', default = 1024, type=int)
+    parser.add_argument('--resume_training', action='store_true')
+    # parser.add_argument('--ckpt_path', default = , type=str)
+    
+
     args = parser.parse_args()
     return args
 
@@ -39,9 +44,10 @@ def main(cfg):
     #---->load loggers
     cfg.load_loggers = load_loggers(cfg)
     # print(cfg.load_loggers)
+    save_path = Path(cfg.load_loggers[0].log_dir) 
 
     #---->load callbacks
-    cfg.callbacks = load_callbacks(cfg)
+    cfg.callbacks = load_callbacks(cfg, save_path)
 
     #---->Define Data 
     # DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
@@ -58,11 +64,12 @@ def main(cfg):
                 'batch_size': cfg.Data.train_dataloader.batch_size,
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
-                'backbone': cfg.Model.backbone,
                 'bag_size': cfg.Data.bag_size,
                 }
 
-    dm = MILDataModule(**DataInterface_dict)
+    if cfg.Data.cross_val:
+        dm = CrossVal_MILDataModule(**DataInterface_dict)
+    else: dm = MILDataModule(**DataInterface_dict)
     
 
     #---->Define Model
@@ -82,9 +89,9 @@ def main(cfg):
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        # gpus=cfg.General.gpus,
-        gpus = [2,3],
-        strategy='ddp',
+        gpus=cfg.General.gpus,
+        # gpus = [0,2],
+        # strategy='ddp',
         amp_backend='native',
         # amp_level=cfg.General.amp_level,  
         precision=cfg.General.precision,  
@@ -96,12 +103,31 @@ def main(cfg):
     )
 
     #---->train or test
+    if cfg.resume_training:
+        # resume from the last checkpoint of the given version, then return so
+        # a fresh fit is not started again below
+        last_ckpt = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' / 'last.ckpt'
+        trainer.fit(model = model, datamodule = dm, ckpt_path=last_ckpt)
+        return
+
     if cfg.General.server == 'train':
-        trainer.fit_loop = KFoldLoop(3, trainer.fit_loop, )
-        trainer.fit(model = model, datamodule = dm)
+
+        # k-fold cross validation loop
+        if cfg.Data.cross_val: 
+            internal_fit_loop = trainer.fit_loop
+            trainer.fit_loop = KFoldLoop(cfg.Data.nfold, export_path = cfg.log_path, **ModelInterface_dict)
+            trainer.fit_loop.connect(internal_fit_loop)
+            trainer.fit(model, dm)
+        else:
+            trainer.fit(model = model, datamodule = dm)
     else:
-        model_paths = list(cfg.log_path.glob('*.ckpt'))
+        log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' 
+
+        test_path = Path(log_path) / 'test'
+        for n in range(cfg.Model.n_classes):
+            n_output_path = test_path / str(n)
+            n_output_path.mkdir(parents=True, exist_ok=True)
+        # print(cfg.log_path)
+        model_paths = list(log_path.glob('*.ckpt'))
         model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
+        # model_paths = [f'{log_path}/epoch=279-val_loss=0.4009.ckpt']
         for path in model_paths:
             print(path)
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
@@ -120,6 +146,16 @@ if __name__ == '__main__':
     cfg.Data.fold = args.fold
     cfg.Loss.base_loss = args.loss
     cfg.Data.bag_size = args.bag_size
+    cfg.version = args.version
+
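+    # experiment directory layout: logs/<config dir>/<model>/<task>/_<backbone>_<loss>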
+    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+    log_name =  f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name 
     
 
     #---->main
diff --git a/train_loop.py b/train_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9236814589b30d30aab061aaf52599db5a130df
--- /dev/null
+++ b/train_loop.py
@@ -0,0 +1,212 @@
+from pytorch_lightning import LightningModule
+import torch
+import torch.nn.functional as F
+from torchmetrics.classification.accuracy import Accuracy
+import os.path as osp
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from pytorch_lightning import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
+from datasets.data_interface import BaseKFoldDataModule
+from typing import Any, Dict, List, Optional, Type
+import torchmetrics
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+
+class EnsembleVotingModel(LightningModule):
+    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str], n_classes, log_path) -> None:
+        super().__init__()
+        # Create `num_folds` models with their associated fold weights
+        self.n_classes = n_classes
+        self.log_path = log_path
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
+        self.test_acc = Accuracy()
+        if self.n_classes > 2: 
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
+                                                                           average='micro'),
+                                                     torchmetrics.CohenKappa(num_classes = self.n_classes),
+                                                     torchmetrics.F1Score(num_classes = self.n_classes,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = self.n_classes),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = self.n_classes),
+                                                     torchmetrics.Specificity(average = 'macro',
+                                                                            num_classes = self.n_classes)])
+                                                                            
+        else : 
+            self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
+                                                                           average = 'micro'),
+                                                     torchmetrics.CohenKappa(num_classes = 2),
+                                                     torchmetrics.F1Score(num_classes = 2,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = 2),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = 2)])
+        self.test_metrics = metrics.clone(prefix = 'test_')
+        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
+
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        # Compute the averaged predictions over the `num_folds` models.
+        input, label, _ = batch
+        label = label.float()
+        input = input.squeeze(0).float()
+
+        # soft voting: average the logits from all fold models
+        logits = torch.stack([m(input) for m in self.models]).mean(0)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+        # #---->acc log
+        Y = torch.argmax(label)
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+
+    def test_epoch_end(self, output_results):
+        probs = torch.cat([x['Y_prob'] for x in output_results])
+        max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.cat([x['label'] for x in output_results])
+        target = torch.argmax(target, dim=1)
+        
+        #---->
+        auc = self.AUROC(probs, target.squeeze())
+        metrics = self.test_metrics(max_probs.squeeze() , target)
+
+
+        # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
+        metrics['test_auc'] = auc
+
+        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+
+        for keys, values in metrics.items():
+            print(f'{keys} = {values}')
+            metrics[keys] = values.cpu().numpy()
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+        self.log_confusion_matrix(probs, target, stage='test')
+        #---->
+        result = pd.DataFrame([metrics])
+        result.to_csv(self.log_path / 'result.csv')
+
+
+    def log_confusion_matrix(self, max_probs, target, stage):
+            confmat = self.confusion_matrix(max_probs.squeeze(), target)
+            df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+            plt.figure()
+            fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+            # plt.close(fig_)
+            # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
+            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+            if stage == 'test':
+                plt.savefig(f'{self.log_path}/cm_test')
+            plt.close(fig_)
+
+class KFoldLoop(Loop):
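+    """Outer loop that runs the wrapped FitLoop once per fold, checkpoints each
+    fold's weights, and finally evaluates an EnsembleVotingModel built from all
+    fold checkpoints (adapted from the Lightning k-fold example)."""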
+    def __init__(self, num_folds: int, export_path: str, **kwargs) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+        self.n_classes = kwargs["model"].n_classes
+        self.log_path = kwargs["log"]
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to the run a fitting and testing on the current hold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths, n_classes=self.n_classes, log_path=self.log_path)
+        voting_model.trainer = self.trainer
+        # connect the new model to the trainer and move it to the right device
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # requires to be overridden as attributes of the wrapped loop are being accessed.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
index 704aade8ba3e8bcbedf237b3fde5db80b84fcd32..33d3d005a578948dd3d5b58938a815f86c4e92f5 100644
Binary files a/utils/__pycache__/utils.cpython-39.pyc and b/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/utils/utils.py b/utils/utils.py
index 010eaab0c51e6fc784774502d85e7eef7303487e..814cf1df0d1258fa19663208e91cd23ece06f6eb 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -10,10 +10,11 @@ from torch.utils.data.dataset import Dataset, Subset
 from torchmetrics.classification.accuracy import Accuracy
 
 from pytorch_lightning import LightningDataModule, seed_everything, Trainer
-from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning import LightningModule
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.trainer.states import TrainerFn
+from typing import Any, Dict, List, Optional, Type
 
 #---->read yaml
 import yaml
@@ -27,30 +28,30 @@ def read_yaml(fpath=None):
 from pytorch_lightning import loggers as pl_loggers
 def load_loggers(cfg):
 
-    log_path = cfg.General.log_path
-    Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
-    version_name = Path(cfg.config).name[:-5]
     
     
     #---->TensorBoard
     if cfg.stage != 'test':
-        cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
-        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                                name = version_name, version = f'fold{cfg.Data.fold}',
+        tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
                                                 log_graph = True, default_hp_metric = False)
         #---->CSV
-        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                        name = version_name, version = f'fold{cfg.Data.fold}', )
+        csv_logger = pl_loggers.CSVLogger(cfg.log_path)
     else:  
-        cfg.log_path = Path(log_path) / log_name / version_name / f'test'
-        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                                name = version_name, version = f'test',
+        cfg.log_path = Path(cfg.log_path) / 'test'
+        tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
+                                                version = 'test',
                                                 log_graph = True, default_hp_metric = False)
         #---->CSV
-        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                        name = version_name, version = f'test', )
-                                    
+        csv_logger = pl_loggers.CSVLogger(cfg.log_path, version = 'test')
     
     print(f'---->Log dir: {cfg.log_path}')
 
@@ -63,11 +64,11 @@ from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
-def load_callbacks(cfg):
+def load_callbacks(cfg, save_path):
 
     Mycallbacks = []
     # Make output path
-    output_path = cfg.log_path
+    output_path = save_path / 'checkpoints' 
     output_path.mkdir(exist_ok=True, parents=True)
 
     early_stop_callback = EarlyStopping(
@@ -94,8 +95,9 @@ def load_callbacks(cfg):
     Mycallbacks.append(progress_bar)
 
     if cfg.General.server == 'train' :
+        # save_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.resume_version}' / last.ckpt
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
-                                         dirpath = str(cfg.log_path),
+                                         dirpath = str(output_path),
                                          filename = '{epoch:02d}-{val_loss:.4f}',
                                          verbose = True,
                                          save_last = True,
@@ -103,7 +105,7 @@ def load_callbacks(cfg):
                                          mode = 'min',
                                          save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
-                                         dirpath = str(cfg.log_path),
+                                         dirpath = str(output_path),
                                          filename = '{epoch:02d}-{val_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
@@ -136,87 +138,3 @@ def convert_labels_for_task(task, label):
     return label_map[task][label]
 
 
-#-----> KFOLD LOOP
-
-class KFoldLoop(Loop):
-    def __init__(self, num_folds: int, export_path: str) -> None:
-        super().__init__()
-        self.num_folds = num_folds
-        self.current_fold: int = 0
-        self.export_path = export_path
-
-    @property
-    def done(self) -> bool:
-        return self.current_fold >= self.num_folds
-
-    def connect(self, fit_loop: FitLoop) -> None:
-        self.fit_loop = fit_loop
-
-    def reset(self) -> None:
-        """Nothing to reset in this loop."""
-
-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
-        model."""
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_folds(self.num_folds)
-        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
-
-    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
-        print(f"STARTING FOLD {self.current_fold}")
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_fold_index(self.current_fold)
-
-    def advance(self, *args: Any, **kwargs: Any) -> None:
-        """Used to the run a fitting and testing on the current hold."""
-        self._reset_fitting()  # requires to reset the tracking stage.
-        self.fit_loop.run()
-
-        self._reset_testing()  # requires to reset the tracking stage.
-        self.trainer.test_loop.run()
-        self.current_fold += 1  # increment fold tracking number.
-
-    def on_advance_end(self) -> None:
-        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
-        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
-        # restore the original weights + optimizers and schedulers.
-        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
-        self.trainer.strategy.setup_optimizers(self.trainer)
-        self.replace(fit_loop=FitLoop)
-
-    def on_run_end(self) -> None:
-        """Used to compute the performance of the ensemble model on the test set."""
-        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
-        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
-        voting_model.trainer = self.trainer
-        # This requires to connect the new model and move it the right device.
-        self.trainer.strategy.connect(voting_model)
-        self.trainer.strategy.model_to_device()
-        self.trainer.test_loop.run()
-
-    def on_save_checkpoint(self) -> Dict[str, int]:
-        return {"current_fold": self.current_fold}
-
-    def on_load_checkpoint(self, state_dict: Dict) -> None:
-        self.current_fold = state_dict["current_fold"]
-
-    def _reset_fitting(self) -> None:
-        self.trainer.reset_train_dataloader()
-        self.trainer.reset_val_dataloader()
-        self.trainer.state.fn = TrainerFn.FITTING
-        self.trainer.training = True
-
-    def _reset_testing(self) -> None:
-        self.trainer.reset_test_dataloader()
-        self.trainer.state.fn = TrainerFn.TESTING
-        self.trainer.testing = True
-
-    def __getattr__(self, key) -> Any:
-        # requires to be overridden as attributes of the wrapped loop are being accessed.
-        if key not in self.__dict__:
-            return getattr(self.fit_loop, key)
-        return self.__dict__[key]
-
-    def __setstate__(self, state: Dict[str, Any]) -> None:
-        self.__dict__.update(state)
\ No newline at end of file