diff --git a/DeepGraft/TransMIL_simple.yaml b/DeepGraft/Resnet50.yaml
similarity index 93%
rename from DeepGraft/TransMIL_simple.yaml
rename to DeepGraft/Resnet50.yaml
index 4501f2c06242bb6cb951fef458eab000daac0d75..e6b780b18947b315ee77bb0c6b8b0dbd433fa249 100644
--- a/DeepGraft/TransMIL_simple.yaml
+++ b/DeepGraft/Resnet50.yaml
@@ -5,7 +5,7 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [1]
+    gpus: [0]
     epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
@@ -32,9 +32,8 @@ Data:
         
 
 Model:
-    name: TransMIL
+    name: resnet50
     n_classes: 2
-    backbone: simple
 
 
 Optimizer:
diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d83ce0d902228644d4fdd8a3059cd4b135a69fde
--- /dev/null
+++ b/DeepGraft/TransMIL_debug.yaml
@@ -0,0 +1,50 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [1]
+    epochs: &epoch 200 
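+    # '&epoch' defines a YAML anchor; other config entries can reuse the value via *epoch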
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
+    fold: 0
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet18
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_dino.yaml b/DeepGraft/TransMIL_dino.yaml
index ffe987c913ff9075730d7b16b8e9dba16d5d4978..b7161ba63b98a38a09a8a545e0f661b13b342914 100644
--- a/DeepGraft/TransMIL_dino.yaml
+++ b/DeepGraft/TransMIL_dino.yaml
@@ -6,10 +6,10 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [0]
-    epochs: &epoch 1000 
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
@@ -17,7 +17,7 @@ Data:
     dataset_name: custom
     data_shuffle: False
     data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
-    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 0
     nfold: 4
 
diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..79d8ea88865ebddd6a918bdc4b9435b6ba973ff1
--- /dev/null
+++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 5
+    backbone: efficientnet
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8780060ebd1475b273b06498800523d7108085b1
--- /dev/null
+++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 4
+    backbone: efficientnet
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f69b5bfadaa3023c9b6113e59681daee73b29fe2
--- /dev/null
+++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: train #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: efficientnet
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet18_all.yaml b/DeepGraft/TransMIL_resnet18_all.yaml
index 8fa5818981b31c1a6c36f34b42e055f2198681fc..f331e4ee092dd8a736e56ff1737514dffd2d1ad2 100644
--- a/DeepGraft/TransMIL_resnet18_all.yaml
+++ b/DeepGraft/TransMIL_resnet18_all.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet18_no_other.yaml b/DeepGraft/TransMIL_resnet18_no_other.yaml
index 95a9bd64692f5ee12dd822e04fea889b80717457..c7a27f228ec2cbeb77a9456af063f8868c9e877f 100644
--- a/DeepGraft/TransMIL_resnet18_no_other.yaml
+++ b/DeepGraft/TransMIL_resnet18_no_other.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet18_no_viral.yaml b/DeepGraft/TransMIL_resnet18_no_viral.yaml
index 155b676a24e0541f42f8fee12af2d987217c0525..93054bf660eaecfdc7fd8e9ac7ceecf8007b12b9 100644
--- a/DeepGraft/TransMIL_resnet18_no_viral.yaml
+++ b/DeepGraft/TransMIL_resnet18_no_viral.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
index e7d9bf0694a227f987d1d2fbf2f0facb53c248d5..c26e1e9ff0329f32efbfd96334c6d1ac957d90bb 100644
--- a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet50_all.yaml b/DeepGraft/TransMIL_resnet50_all.yaml
index eba3a4fa20870ad4fc2b173ccb4e60086ddb3ac5..e6959eac839447b10e1a89ef3f09001ab4318594 100644
--- a/DeepGraft/TransMIL_resnet50_all.yaml
+++ b/DeepGraft/TransMIL_resnet50_all.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 1000 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: train #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet50_no_other.yaml b/DeepGraft/TransMIL_resnet50_no_other.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d3cd2aa27224aeecd541edff40643dda60b4609e
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_no_other.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 5
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50_no_viral.yaml b/DeepGraft/TransMIL_resnet50_no_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2e3394a8342caab9824528b5ec254bc6c91b7ba3
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 4
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a756616cd4a128874731b17353d108856f72e9f4
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/MyBackbone/__init__.py b/MyBackbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fca298edf4ac7bdfb723bcaddf484dad284ae25
--- /dev/null
+++ b/MyBackbone/__init__.py
@@ -0,0 +1,2 @@
+
+from .backbone_factory import init_backbone
\ No newline at end of file
diff --git a/MyBackbone/backbone_factory.py b/MyBackbone/backbone_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff770e583fc3ddc4712424f9381ed9adeb8b7742
--- /dev/null
+++ b/MyBackbone/backbone_factory.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+
+from transformers import AutoFeatureExtractor, ViTModel
+from torchvision import models
+
+
+class View(nn.Module):
+    # Reshape helper referenced by the backbone heads below; assumed
+    # implementation, since no View is defined or imported in this module.
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, x):
+        return x.view(*self.shape)
+
+def init_backbone(**kargs):
+    
+    backbone = kargs['backbone']
+    n_classes = kargs['n_classes']
+    out_features = kargs['out_features']
+
+    if backbone == 'dino' or backbone == 'vit':
+
+        if backbone == 'dino':
+            feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+            vit = ViTModel.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+        def model_ft(input):
+            # preprocess with the HuggingFace extractor, then run the ViT
+            # (note: only the 'dino' case is handled here)
+            inputs = feature_extractor(input, return_tensors='pt')
+            return vit(**inputs)
+
+    elif kargs['backbone'] == 'resnet18':
+        resnet18 = models.resnet18(pretrained=True)
+        modules = list(resnet18.children())[:-1]
+        # model_ft.fc = nn.Linear(512, out_features)
+
+        res18 = nn.Sequential(
+            *modules,
+        )
+        for param in res18.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res18,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 512)),
+            nn.Linear(512, out_features),
+            nn.GELU(),
+        )
+    elif kargs['backbone'] == 'resnet50':
+
+        resnet50 = models.resnet50(pretrained=True)    
+        # model_ft.fc = nn.Linear(1024, out_features)
+        modules = list(resnet50.children())[:-3]
+        res50 = nn.Sequential(
+            *modules,     
+        )
+        for param in res50.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res50,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.GELU()
+        )
+    elif kargs['backbone'] == 'efficientnet':
+        efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+        for param in efficientnet.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            efficientnet,
+            nn.Linear(1000, 512),
+            nn.GELU(),
+        )
+    elif kargs['backbone'] == 'simple': #mil-ab attention
+        model_ft = nn.Sequential(
+            nn.Conv2d(3, 20, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            nn.Conv2d(20, 50, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.ReLU(),
+        )
+
+    return model_ft
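+
+# Minimal usage sketch (assumption: out_features=512 to match the TransMIL
+# feature dimension used elsewhere in this repo):
+if __name__ == '__main__':
+    model_ft = init_backbone(backbone='resnet18', n_classes=2, out_features=512)
+    dummy_bag = torch.randn(4, 3, 256, 256)  # a bag of four RGB tiles
+    print(model_ft(dummy_bag).shape)         # torch.Size([4, 512])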
diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc
index ed7437016bbcadc48f9bb0a97401529dc574c9b2..14452dde7b34fd11b5255f4ed07bdeb48a27e49f 100644
Binary files a/MyLoss/__pycache__/loss_factory.cpython-39.pyc and b/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/poly_loss.cpython-39.pyc b/MyLoss/__pycache__/poly_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08edf030976f28689763861889e7ace3d3747a59
Binary files /dev/null and b/MyLoss/__pycache__/poly_loss.cpython-39.pyc differ
diff --git a/MyLoss/loss_factory.py b/MyLoss/loss_factory.py
index 1dffa6182ef40b1f46436b8a1eadb0a8906c17ff..f3bdcebb96480a0ed6de1e33a0de2defaf1792ea 100755
--- a/MyLoss/loss_factory.py
+++ b/MyLoss/loss_factory.py
@@ -13,6 +13,7 @@ from .hausdorff import HausdorffDTLoss, HausdorffERLoss
 from .lovasz_loss import LovaszSoftmax
 from .ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss,\
      WeightedCrossEntropyLossV2, DisPenalizedCE
+from .poly_loss import PolyLoss
 
 from pytorch_toolbelt import losses as L
 
@@ -22,7 +23,7 @@ def create_loss(args, w1=1.0, w2=0.5):
     # mode = args.base_loss #BINARY_MODE \MULTICLASS_MODE \MULTILABEL_MODE 
     loss = None
     if hasattr(nn, conf_loss): 
-        loss = getattr(nn, conf_loss)() 
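+        # note: label_smoothing is a kwarg of nn.CrossEntropyLoss (torch >= 1.10);
+        # other nn losses resolved via getattr would reject it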
+        loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
     #binary loss
     elif conf_loss == "focal":
         loss = L.BinaryFocalLoss()
@@ -46,6 +47,8 @@ def create_loss(args, w1=1.0, w2=0.5):
         loss = L.JointLoss(BCEWithLogitsLoss(), L.BinaryDiceLogLoss(), w1, w2)
     elif conf_loss == "reduced_focal":
         loss = L.BinaryFocalLoss(reduced=True)
+    elif conf_loss == "polyloss":
+        loss = PolyLoss(softmax=False)
     else:
         assert False and "Invalid loss"
         raise ValueError
diff --git a/MyLoss/poly_loss.py b/MyLoss/poly_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3705458eff14c788f2d6719b923d6498cc3c74d
--- /dev/null
+++ b/MyLoss/poly_loss.py
@@ -0,0 +1,84 @@
+# From https://github.com/yiyixuxu/polyloss-pytorch
+
+import warnings
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+
+def to_one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
+    # if `dim` is bigger, add singleton dim at the end
+    if labels.ndim < dim + 1:
+        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
+        labels = torch.reshape(labels, shape)
+
+    sh = list(labels.shape)
+
+    if sh[dim] != 1:
+        raise AssertionError("labels should have a channel with length equal to one.")
+
+    sh[dim] = num_classes
+
+    o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
+    labels = o.scatter_(dim=dim, index=labels.long(), value=1)
+
+    return labels
+
+
+class PolyLoss(_Loss):
+    def __init__(self,
+                 softmax: bool = False,
+                 ce_weight: Optional[torch.Tensor] = None,
+                 reduction: str = 'mean',
+                 epsilon: float = 1.0,
+                 ) -> None:
+        super().__init__()
+        self.softmax = softmax
+        self.reduction = reduction
+        self.epsilon = epsilon
+        self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction='none')
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            input: the shape should be BNH[WD], where N is the number of classes.
+                You can pass logits or probabilities as input, if pass logit, must set softmax=True
+            target: the shape should be BNH[WD] (one-hot format) or B1H[WD], where N is the number of classes.
+                It should contain binary values
+        Raises:
+            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+       """
+        n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
+        # target not in one-hot encode format, has shape B1H[WD]
+        if n_pred_ch != n_target_ch:
+            # squeeze out the channel dimension of size 1 to calculate ce loss
+            self.ce_loss = self.cross_entropy(input, torch.squeeze(target, dim=1).long())
+            # convert the target into one-hot format for the pt term below
+            target = to_one_hot(target, num_classes=n_pred_ch)
+        else:
+            # target is in one-hot format; convert to BH[WD] to calculate ce loss
+            self.ce_loss = self.cross_entropy(input, torch.argmax(target, dim=1))
+
+        if self.softmax:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `softmax=True` ignored.")
+            else:
+                input = torch.softmax(input, 1)
+
+        pt = (input * target).sum(dim=1)  # BH[WD]
+        poly_loss = self.ce_loss + self.epsilon * (1 - pt)
+
+        if self.reduction == 'mean':
+            polyl = torch.mean(poly_loss)  # the batch and channel average
+        elif self.reduction == 'sum':
+            polyl = torch.sum(poly_loss)  # sum over the batch and channel dims
+        elif self.reduction == 'none':
+            # If we are not computing voxelwise loss components at least
+            # make sure a none reduction maintains a broadcastable shape
+            # BH[WD] -> BH1[WD]
+            polyl = poly_loss.unsqueeze(1)
+        else:
+            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+        return polyl
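+
+
+# Minimal usage sketch: logits for 4 samples over 2 classes, integer targets
+# with an explicit channel dimension of size 1 (shape [B, 1]).
+if __name__ == '__main__':
+    logits = torch.randn(4, 2)
+    target = torch.randint(0, 2, (4, 1))
+    loss_fn = PolyLoss(softmax=True)  # softmax=True because we pass logits
+    print(loss_fn(logits, target))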
diff --git a/README.md b/README.md
index 04c7dbabf81b17979326fdff17670a1b5169e596..effc0fe369c45153ec681403569c710ea114baf3 100644
--- a/README.md
+++ b/README.md
@@ -9,3 +9,7 @@ python train.py --stage='train' --config='Camelyon/TransMIL.yaml'  --gpus=0 --fo
 ```python
 python train.py --stage='test' --config='Camelyon/TransMIL.yaml'  --gpus=0 --fold=0
 ```
+
+
+### Changes Made: 
+
diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc
index 147b3fc3c4628d6eed5bebd6ff248d31aaeebb34..4a200bbb7d34328019d9bb9e604b0524127ae863 100644
Binary files a/datasets/__pycache__/custom_dataloader.cpython-39.pyc and b/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc
index 4af0141496c11d6f97255add8746fa0067e55636..9550db1509e8d477a1c15534a76c6f87976fbec4 100644
Binary files a/datasets/__pycache__/data_interface.cpython-39.pyc and b/datasets/__pycache__/data_interface.cpython-39.pyc differ
diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py
index ddc2ed36c64fc46f0844673514b51cead0052bc6..02850f5db5e263a9d5cdbbfff923952dd5dbfa52 100644
--- a/datasets/custom_dataloader.py
+++ b/datasets/custom_dataloader.py
@@ -9,12 +9,25 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 # from histoTransforms import RandomHueSaturationValue
 import torchvision.transforms as transforms
+import torchvision
 import torch.nn.functional as F
 import csv
 from PIL import Image
 import cv2
 import pandas as pd
 import json
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from transformers import AutoFeatureExtractor
+from imgaug import augmenters as iaa
+import imgaug as ia
+from torchsampler import ImbalancedDatasetSampler
+
+
+class RangeNormalization(object):
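+    """Map uint8 images from [0, 255] to [-1, 1]."""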
+    def __call__(self, sample):
+        img = sample
+        return (img / 255.0 - 0.5) / 0.5
 
 class HDF5MILDataloader(data.Dataset):
     """Represents an abstract HDF5 dataset. For single H5 container! 
@@ -28,68 +41,89 @@ class HDF5MILDataloader(data.Dataset):
         data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
 
     """
-    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20):
+    def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024):
         super().__init__()
 
         self.data_info = []
         self.data_cache = {}
         self.slideLabelDict = {}
+        self.files = []
         self.data_cache_size = data_cache_size
         self.mode = mode
         self.file_path = file_path
         # self.csv_path = csv_path
         self.label_path = label_path
         self.n_classes = n_classes
-        self.bag_size = 120
+        self.bag_size = bag_size
+        self.backbone = backbone
         # self.label_file = label_path
         recursive = True
+        
 
         # read labels and slide_path from csv
-
-        # df = pd.read_csv(self.csv_path)
-        # labels = df.LABEL
-        # slides = df.FILENAME
         with open(self.label_path, 'r') as f:
-            self.slideLabelDict = json.load(f)[mode]
-
-        self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict}
-
-            
-        # if Path(slides[0]).suffix:
-        #     slides = list(map(lambda x: Path(x).stem, slides))
-
-        # print(labels)
-        # print(slides)
-        # self.slideLabelDict = dict(zip(slides, labels))
-        # print(self.slideLabelDict)
-
-        #check if files in slideLabelDict, only take files that are available.
+            temp_slide_label_dict = json.load(f)[mode]
+            for (x, y) in temp_slide_label_dict:
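+                # keep only slides whose .hdf5 file exists under file_path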
+                x = Path(x).stem 
+                x_complete_path = Path(self.file_path)/Path(x + '.hdf5')
+                if x_complete_path.is_file():
+                    self.slideLabelDict[x] = y
+                    self.files.append(x_complete_path)
 
-        files_in_path = list(Path(self.file_path).rglob('*.hdf5'))
-        files_in_path = [x.stem for x in files_in_path]
-        # print(len(files_in_path))
-        # print(files_in_path)
-        # print(list(self.slideLabelDict.keys()))
-        # for x in list(self.slideLabelDict.keys()):
-        #     if x in files_in_path:
-        #         path = Path(self.file_path) / (x + '.hdf5')
-        #         print(path)
-
-        self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path]
-
-        print(len(self.files))
-        # self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys())))
 
         for h5dataset_fp in tqdm(self.files):
-            # print(h5dataset_fp)
             self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
 
-        # print(self.data_info)
-        self.resize_transforms = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(256),
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
         ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+
+        # self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+        #     # A.RandomGamma(),
+        #     # A.HorizontalFlip(),
+        #     # A.VerticalFlip(),
+        #     # A.RandomRotate90(),
+        #     # A.OneOf([
+        #     #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #     #     A.Affine(
+        #     #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+        #     #         rotate=(-45, 45),
+        #     #         shear=(-4, 4),
+        #     #         cval=8,
+        #     #         )
+        #     # ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
 
+        ])
         self.img_transforms = transforms.Compose([    
             transforms.RandomHorizontalFlip(p=1),
             transforms.RandomVerticalFlip(p=1),
@@ -100,32 +134,45 @@ class HDF5MILDataloader(data.Dataset):
             RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
             transforms.ToTensor()
         ])
+        if self.backbone == 'dino':
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
 
         # self._add_data_infos(load_data)
 
-
     def __getitem__(self, index):
         # get data
         batch, label, name = self.get_data(index)
         out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
         
         if self.mode == 'train':
             # print(img)
             # print(img.shape)
-            for img in batch: 
-                img = self.img_transforms(img)
-                img = self.hsv_transforms(img)
+            for img in batch: # expects numpy 
+                img = img.numpy().astype(np.uint8)
+                if self.backbone == 'dino':
+                    img = self.feature_extractor(images=img, return_tensors='pt')
+                    # img = self.resize_transforms(img)
+                img = seq_img_d.augment_image(img)
+                img = self.val_transforms(img)
                 out_batch.append(img)
 
         else:
             for img in batch:
-                img = transforms.functional.to_tensor(img)
+                img = img.numpy().astype(np.uint8)
+                if self.backbone == 'dino':
+                    img = self.feature_extractor(images=img, return_tensors='pt')
+                    img = self.resize_transforms(img)
+                
+                img = self.val_transforms(img)
                 out_batch.append(img)
+
         if len(out_batch) == 0:
             # print(name)
-            out_batch = torch.randn(100,3,256,256)
+            out_batch = torch.randn(self.bag_size,3,256,256)
         else: out_batch = torch.stack(out_batch)
-        out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+        # print(out_batch.shape)
+        # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
 
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
@@ -138,29 +185,7 @@ class HDF5MILDataloader(data.Dataset):
         wsi_name = Path(file_path).stem
         if wsi_name in self.slideLabelDict:
             label = self.slideLabelDict[wsi_name]
-            wsi_batch = []
-            # with h5py.File(file_path, 'r') as h5_file:
-            #     numKeys = len(h5_file.keys())
-            #     sample = list(h5_file.keys())[0]
-            #     shape = (numKeys,) + h5_file[sample][:].shape
-                # for tile in h5_file.keys():
-                #     img = h5_file[tile][:]
-                    
-                    # print(img)
-                    # if type == 'images':
-                    #     t = 'data'
-                    # else: 
-                    #     t = 'label'
             idx = -1
-            # if load_data: 
-            #     for tile in h5_file.keys():
-            #         img = h5_file[tile][:]
-            #         img = img.astype(np.uint8)
-            #         img = self.resize_transforms(img)
-            #         wsi_batch.append(img)
-            #     idx = self._add_to_cache(wsi_batch, file_path)
-                #     wsi_batch.append(img)
-            # self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx})
             self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
 
     def _load_data(self, file_path):
@@ -173,30 +198,15 @@ class HDF5MILDataloader(data.Dataset):
             for tile in h5_file.keys():
                 img = h5_file[tile][:]
                 img = img.astype(np.uint8)
-                img = self.resize_transforms(img)
+                img = torch.from_numpy(img)
+                # img = self.resize_transforms(img)
                 wsi_batch.append(img)
+            wsi_batch = torch.stack(wsi_batch)
+            wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size)
             idx = self._add_to_cache(wsi_batch, file_path)
             file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
             self.data_info[file_idx + idx]['cache_idx'] = idx
 
-            # for type in ['images', 'labels']:
-            #     for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()):
-            #         img = h5_file[data_path][:]
-            #         idx = self._add_to_cache(img, data_path)
-            #         file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path)
-            #         self.data_info[file_idx + idx]['cache_idx'] = idx
-            # for gname, group in h5_file.items():
-            #     for dname, ds in group.items():
-            #         # add data to the data cache and retrieve
-            #         # the cache index
-            #         idx = self._add_to_cache(ds.value, file_path)
-
-            #         # find the beginning index of the hdf5 file we are looking for
-            #         file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)
-
-            #         # the data info should have the same index since we loaded it in the same way
-            #         self.data_info[file_idx + idx]['cache_idx'] = idx
-
         # remove an element from data cache if size was exceeded
         if len(self.data_cache) > self.data_cache_size:
             # remove one item from the cache at random
@@ -223,6 +233,182 @@ class HDF5MILDataloader(data.Dataset):
     #     data_info_type = [di for di in self.data_info if di['type'] == type]
     #     return data_info_type
 
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_labels(self, indices):
+
+        return [self.data_info[i]['label'] for i in indices]
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+            dataset. This will make sure that the data is loaded in case it is
+            not part of the data cache.
+            i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+        
+        # get new cache_idx assigned by _load_data_info
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name
+
+class GenesisDataloader(data.Dataset):
+    """Represents an abstract HDF5 dataset. For single H5 container! 
+    
+    Input params:
+        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
+        mode: 'train' or 'test'
+        load_data: If True, loads all the data immediately into RAM. Use this if
+            the dataset fits into memory. Otherwise, leave this at False and 
+            the data will load lazily.
+        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+
+    """
+    def __init__(self, file_path, label_path, mode, load_data=False, data_cache_size=5, debug=False):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        # self.n_classes = n_classes
+        self.bag_size = 120
+        # self.transforms = transforms
+        self.input_size = 256
+
+        # self.files = list(Path(self.file_path).rglob('*.hdf5'))
+        home = Path.cwd().parts[1]
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            for x in temp_slide_label_dict:
+                
+                if Path(x).parts[1] != home: 
+                    x = x.replace(Path(x).parts[1], home)
+                self.files.append(Path(x))
+                # x = Path(x).stem 
+                # x_complete_path = Path(self.file_path)/Path(x + '.hdf5')
+                # if x_complete_path.is_file():
+                #     self.slideLabelDict[x] = y
+                #     self.files.append(x_complete_path)
+
+        for h5dataset_fp in tqdm(self.files):
+            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
+
+        self.resize_transforms = torchvision.transforms.Compose([
+            torchvision.transforms.ToPILImage(),
+            torchvision.transforms.Resize(self.input_size),
+        ])
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
+        ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+
+        ])
+
+
+    def __getitem__(self, index):
+        # get data
+        img, label, name = self.get_data(index)
+        # out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
+        
+        if self.mode == 'train':
+
+            img = img.numpy().astype(np.uint8)
+            img = seq_img_d.augment_image(img)
+            img = self.val_transforms(img)
+        else:
+            img = img.numpy().astype(np.uint8)
+            img = self.val_transforms(img)
+            # out_batch.append(img)
+
+        return {'data': img, 'label': label}
+        # return out_batch, label
+
+    def __len__(self):
+        return len(self.data_info)
+    
+    def _add_data_infos(self, file_path, load_data):
+        img_name = Path(file_path).stem
+        label = Path(file_path).parts[-2]
+
+        idx = -1
+        self.data_info.append({'data_path': file_path, 'label': label, 'name': img_name, 'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        with h5py.File(file_path, 'r') as h5_file:
+
+            tile = list(h5_file.keys())[0]
+            img = h5_file[tile][:]
+            img = img.astype(np.uint8)
+                # img = self.resize_transforms(img)
+                # wsi_batch.append(img)
+            idx = self._add_to_cache(img, file_path)
+            file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+            self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # remove invalid cache_idx
+            # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+        """
+        if data_path not in self.data_cache:
+            self.data_cache[data_path] = [data]
+        else:
+            self.data_cache[data_path].append(data)
+        return len(self.data_cache[data_path]) - 1
+
     def get_name(self, i):
         # name = self.get_data_infos(type)[i]['name']
         name = self.data_info[i]['name']
@@ -248,6 +434,45 @@ class HDF5MILDataloader(data.Dataset):
         return self.data_cache[fp][cache_idx], label, name
 
 
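+# Subsample a bag to at most bag_size tiles (random permutation) and zero-pad
+# smaller bags; also returns the number of real tiles, min(bag_size, len(bag)).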
+def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512):
+
+    # get up to bag_size elements
+    bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+    bag_samples = bag[bag_idxs]
+    
+    # zero-pad if we don't have enough samples
+    zero_padded = torch.cat((bag_samples,
+                            torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+    return zero_padded, min(bag_size, len(bag))
+
+
 class RandomHueSaturationValue(object):
 
     def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
@@ -285,48 +510,81 @@ if __name__ == '__main__':
     train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
     data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json'
-    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/'
-    # os.makedirs(output_path, exist_ok=True)
-
-
-    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6)
-    data = DataLoader(dataset, batch_size=1)
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
+    os.makedirs(output_path, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
+    # print(dataset.dataset)
+    a = int(len(dataset)* 0.8)
+    b = int(len(dataset) - a)
+    train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    dl = DataLoader(train_ds,  None, sampler=ImbalancedDatasetSampler(train_ds), num_workers=5)
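+    # note: the next line rebinds dl, dropping the ImbalancedDatasetSampler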
+    dl = DataLoader(train_ds,  None, num_workers=5)
+    
+    # data = DataLoader(dataset, batch_size=1)
 
     # print(len(dataset))
-    x = 0
+    # # x = 0
     c = 0
-    for item in data: 
-        if c >=10:
-            break
+    label_count = [0] *n_classes
+    for item in dl: 
+        # if c >=10:
+        #     break
         bag, label, name = item
-        print(bag)
-        # # print(bag.shape)
-        # if bag.shape[1] == 1:
-        #     print(name)
-        #     print(bag.shape)
+        label_count[np.argmax(label)] += 1
+    print(label_count)
+    print(len(train_ds))
+    #     # # print(bag.shape)
+    #     # if bag.shape[1] == 1:
+    #     #     print(name)
+    #     #     print(bag.shape)
         # print(bag.shape)
-        # print(name)
-        # out_dir = Path(output_path) / name
-        # os.makedirs(out_dir, exist_ok=True)
-
-        # # print(item[2])
-        # # print(len(item))
-        # # print(item[1])
-        # # print(data.shape)
-        # # data = data.squeeze()
-        # bag = item[0]
-        # bag = bag.squeeze()
-        # for i in range(bag.shape[0]):
-        #     img = bag[i, :, :, :]
-        #     img = img.squeeze()
-        #     img = img*255
-        #     img = img.numpy().astype(np.uint8).transpose(1,2,0)
+        
+    #     # out_dir = Path(output_path) / name
+    #     # os.makedirs(out_dir, exist_ok=True)
+
+    #     # # print(item[2])
+    #     # # print(len(item))
+    #     # # print(item[1])
+    #     # # print(data.shape)
+    #     # # data = data.squeeze()
+    #     # bag = item[0]
+    #     bag = bag.squeeze()
+    #     original = original.squeeze()
+    #     for i in range(bag.shape[0]):
+    #         img = bag[i, :, :, :]
+    #         img = img.squeeze()
+            
+    #         img = ((img-img.min())/(img.max() - img.min())) * 255
+    #         print(img)
+    #         # print(img)
+    #         img = img.numpy().astype(np.uint8).transpose(1,2,0)
+
+            
+    #         img = Image.fromarray(img)
+    #         img = img.convert('RGB')
+    #         img.save(f'{output_path}/{i}.png')
+
+
             
-        #     img = Image.fromarray(img)
-        #     img = img.convert('RGB')
-        #     img.save(f'{out_dir}/{i}.png')
-        c += 1
+    #         o_img = original[i,:,:,:]
+    #         o_img = o_img.squeeze()
+    #         print(o_img.shape)
+    #         o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255
+    #         o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0)
+    #         o_img = Image.fromarray(o_img)
+    #         o_img = o_img.convert('RGB')
+    #         o_img.save(f'{output_path}/{i}_original.png')
+        # c += 1
+    #     break
         # else: break
         # print(data.shape)
-        # print(label)
\ No newline at end of file
+        # print(label)
+    # a = [torch.Tensor((3,256,256))]*3
+    # b = torch.stack(a)
+    # print(b)
+    # c = to_fixed_size_bag(b, 512)
+    # print(c)
\ No newline at end of file
diff --git a/datasets/data_interface.py b/datasets/data_interface.py
index 12a0f8c450b4945658d3566cc5c157116bccab66..056e6ff09bcf768c91e0f51e44640487ea4fbee1 100644
--- a/datasets/data_interface.py
+++ b/datasets/data_interface.py
@@ -8,6 +8,8 @@ from torchvision import transforms
 from .camel_dataloader import FeatureBagLoader
 from .custom_dataloader import HDF5MILDataloader
 from pathlib import Path
+from transformers import AutoFeatureExtractor
+from torchsampler import ImbalancedDatasetSampler
 
 class DataInterface(pl.LightningDataModule):
 
@@ -56,9 +58,10 @@ class DataInterface(pl.LightningDataModule):
                                                 train=True)
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
-            print(a)
-            print(b)
-            self.train_dataset, self.val_dataset = random_split(dataset, [a, b])
+            # print(a)
+            # print(b)
+            self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) # returns data.Subset
+
             # self.train_dataset = self.instancialize(state='train')
             # self.val_dataset = self.instancialize(state='val')
  
@@ -72,7 +75,7 @@ class DataInterface(pl.LightningDataModule):
 
 
     def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=True)
+        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
@@ -106,7 +109,7 @@ class DataInterface(pl.LightningDataModule):
 
 class MILDataModule(pl.LightningDataModule):
 
-    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, **kwargs):
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
         super().__init__()
         self.data_root = data_root
         self.label_path = label_path
@@ -121,41 +124,74 @@ class MILDataModule(pl.LightningDataModule):
         self.num_bags_test = 50
         self.seed = 1
 
+        self.backbone = backbone
         self.cache = True
+        self.fe_transform = None
 
 
     def setup(self, stage: Optional[str] = None) -> None:
-        # if self.n_classes == 2:
-        #     if stage in (None, 'fit'):
-        #         dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes)
-        #         a = int(len(dataset)* 0.8)
-        #         b = int(len(dataset) - a)
-        #         self.train_data, self.valid_data = random_split(dataset, [a, b])
-
-        #     if stage in (None, 'test'):
-        #         self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes)
-        # else:
         home = Path.cwd().parts[1]
-        # self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json'
-        # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv'
-        # test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv'
-        
 
         if stage in (None, 'fit'):
-            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
-            # print(len(dataset))
+            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
             self.train_data, self.valid_data = random_split(dataset, [a, b])
 
         if stage in (None, 'test'):
-            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes)
+            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
 
         return super().setup(stage=stage)
 
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True)
+        # alternatives: sampler=ImbalancedDatasetSampler(self.train_data),
+        # batch_transforms=self.transform, pseudo_batch_dim=True
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
+    
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
     
+
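+# Minimal usage sketch (assumption: paths follow the DeepGraft/*.yaml configs):
+#   dm = MILDataModule(data_root='/path/to/hdf5', label_path='/path/to/split.json',
+#                      batch_size=1, n_classes=2, backbone='resnet18')
+#   dm.setup('fit'); loader = dm.train_dataloader()
+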
+class DataModule(pl.LightningDataModule):
+
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+        super().__init__()
+        self.data_root = data_root
+        self.label_path = label_path
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.image_size = 384
+        self.n_classes = n_classes
+        self.target_number = 9
+        self.mean_bag_length = 10
+        self.var_bag_length = 2
+        self.num_bags_train = 200
+        self.num_bags_test = 50
+        self.seed = 1
+
+        self.backbone = backbone
+        self.cache = True
+        self.fe_transform = None
+
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        home = Path.cwd().parts[1]
+        
+        if stage in (None, 'fit'):
+            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
+            a = int(len(dataset)* 0.8)
+            b = int(len(dataset) - a)
+            self.train_data, self.valid_data = random_split(dataset, [a, b])
+
+        if stage in (None, 'test'):
+            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+
+        return super().setup(stage=stage)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=False)
+        # alternatives: sampler=ImbalancedDatasetSampler(self.train_data),
+        # batch_transforms=self.transform, pseudo_batch_dim=True
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
     
diff --git a/fine_tune.py b/fine_tune.py
new file mode 100644
index 0000000000000000000000000000000000000000..5898e755c5d3e4b666599ae324dc81be41425962
--- /dev/null
+++ b/fine_tune.py
@@ -0,0 +1,42 @@
+from transformers import AutoFeatureExtractor, ViTModel
+from transformers import Trainer, TrainingArguments
+from torchvision import models
+import torch
+from datasets.custom_dataloader import DinoDataloader
+
+
+
+def fine_tune_transformer(args):
+
+    data_path = args.data_path
+    model = args.model
+    n_classes = args.n_classes
+
+    if model == 'dino':        
+        feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+        model_ft = ViTModel.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+    training_args = TrainingArguments(
+        output_dir = f'logs/fine_tune/{model}',
+        per_device_train_batch_size=16,
+        evaluation_strategy="steps",
+        num_train_epochs=4,
+        fp16=True,
+        save_steps=100,
+        eval_steps=100,
+        logging_steps=10,
+        learning_rate=2e-4,
+        save_total_limit=2,
+        remove_unused_columns=False,
+        push_to_hub=False,
+        report_to='tensorboard',
+        load_best_model_at_end=True,
+    )
+
+    dataset = DinoDataloader(args.data_path, mode='train') #, transforms=transform
+
+    trainer = Trainer(
+        model=model_ft,
+        args=training_args,
+        train_dataset=dataset,
+    )
+
+    return trainer.train()
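+
+# Usage sketch (hypothetical args namespace; DinoDataloader is assumed to be
+# provided by datasets.custom_dataloader, as imported above):
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(data_path='/path/to/tiles', model='dino', n_classes=2)
+#   fine_tune_transformer(args)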
diff --git a/models/TransMIL.py b/models/TransMIL.py
index ce40a26b37b1886bf5698ee4ab8ecf07c1e4e2c8..69089de0bd934ea38b2854f513295b21cbb73a03 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -86,7 +86,7 @@ class TransMIL(nn.Module):
         h = self.norm(h)[:,0]
 
         #---->predict
-        logits = self._fc2(torch.sigmoid(h)) #[B, n_classes]
+        logits = self._fc2(h) #[B, n_classes]
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim = 1)
         results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
index 4e1ddff6d6f3cbadfd7f1c0c4686a7806f5896e8..cb0eb6fd5085feda9604ec0f24420d85eb1d939d 100644
Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and b/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
index 3e81c0ccdd33e6fe68a5ffe1b602990134944c80..0bf337675c76de4097e7f7723b99de0120f9594c 100644
Binary files a/models/__pycache__/model_interface.cpython-39.pyc and b/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/models/__pycache__/resnet50.cpython-39.pyc b/models/__pycache__/resnet50.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfb660f9829304cb853031dd3d3274396afc38be
Binary files /dev/null and b/models/__pycache__/resnet50.cpython-39.pyc differ
diff --git a/models/model_interface.py b/models/model_interface.py
index 1b0f6e19f8429b764e6fd45b96bf821981d0731e..60b5cc73d6bec977216cdde6bb8bcee3fd267034 100755
--- a/models/model_interface.py
+++ b/models/model_interface.py
@@ -12,18 +12,24 @@ from matplotlib import pyplot as plt
 from MyOptimizer import create_optimizer
 from MyLoss import create_loss
 from utils.utils import cross_entropy_torch
+from timm.loss import AsymmetricLossSingleLabel
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 
 #---->
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics
+from torch import optim as optim
 
 #---->
 import pytorch_lightning as pl
 from .vision_transformer import vit_small
 from torchvision import models
 from torchvision.models import resnet
+from transformers import AutoFeatureExtractor, ViTModel
+
+from captum.attr import LayerGradCam
 
 class ModelInterface(pl.LightningModule):
 
@@ -33,13 +39,17 @@ class ModelInterface(pl.LightningModule):
         self.save_hyperparameters()
         self.load_model()
         self.loss = create_loss(loss)
+        # self.asl = AsymmetricLossSingleLabel()
+        # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
         self.optimizer = optimizer
         self.n_classes = model.n_classes
         self.log_path = kargs['log']
 
         #---->acc
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
-        
         #---->Metrics
         if self.n_classes > 2: 
             self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
@@ -73,35 +83,12 @@ class ModelInterface(pl.LightningModule):
         #--->random
         self.shuffle = kargs['data'].data_shuffle
         self.count = 0
+        self.backbone = kargs['backbone']
 
         self.out_features = 512
         if kargs['backbone'] == 'dino':
-            #---> dino feature extractor
-            arch = 'vit_small'
-            patch_size = 16
-            n_last_blocks = 4
-            # num_labels = 1000
-            avgpool_patchtokens = False
-            home = Path.cwd().parts[1]
-
-            weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth'
-            model = vit_small(patch_size, num_classes=0)
-            # model.eval()
-            # set_parameter_requires_grad(model, feature_extracting)
-            for param in model.parameters():
-                param.requires_grad = False
-            # print(model.embed_dim)
-            # embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens))
-            # model.eval()
-            # print(embed_dim)
-            linear = nn.Linear(model.embed_dim, self.out_features)
-            linear.weight.data.normal_(mean=0.0, std=0.01)
-            linear.bias.data.zero_()
-            
-            self.model_ft = nn.Sequential(
-                model,
-                linear, 
-            )
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+            self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
         elif kargs['backbone'] == 'resnet18':
             resnet18 = models.resnet18(pretrained=True)
             modules = list(resnet18.children())[:-1]
@@ -109,7 +96,6 @@ class ModelInterface(pl.LightningModule):
 
             res18 = nn.Sequential(
                 *modules,
-                
             )
             for param in res18.parameters():
                 param.requires_grad = False
@@ -118,7 +104,7 @@ class ModelInterface(pl.LightningModule):
                 nn.AdaptiveAvgPool2d(1),
                 View((-1, 512)),
                 nn.Linear(512, self.out_features),
-                nn.ReLU(),
+                nn.GELU(),
             )
         elif kargs['backbone'] == 'resnet50':
 
@@ -135,7 +121,17 @@ class ModelInterface(pl.LightningModule):
                 nn.AdaptiveAvgPool2d(1),
                 View((-1, 1024)),
                 nn.Linear(1024, self.out_features),
-                nn.ReLU()
+                # nn.GELU()
+            )
+        elif kargs['backbone'] == 'efficientnet':
+            efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+            for param in efficientnet.parameters():
+                param.requires_grad = False
+            # efn = list(efficientnet.children())[:-1]
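+            # swap the 1280-d widese_b0 classifier head for a projection to self.out_features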
+            efficientnet.classifier.fc = nn.Linear(1280, self.out_features)
+            self.model_ft = nn.Sequential(
+                efficientnet,
+                nn.GELU(),
             )
         elif kargs['backbone'] == 'simple': #mil-ab attention
             feature_extracting = False
@@ -151,21 +147,19 @@ class ModelInterface(pl.LightningModule):
                 nn.ReLU(),
             )
 
-    #---->remove v_num
-    # def get_progress_bar_dict(self):
-    #     # don't show the version number
-    #     items = super().get_progress_bar_dict()
-    #     items.pop("v_num", None)
-    #     return items
-
     def training_step(self, batch, batch_idx):
         #---->inference
         data, label, _ = batch
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
         # print(data.shape)
-        features = self.model_ft(data)
-        
+        if self.backbone == 'dino':
+            # ViTModel's first positional argument is pixel_values; `**data` would fail
+            # on a tensor. Keep the per-patch token embeddings as bag features.
+            features = self.model_ft(data).last_hidden_state
+        else:
+            features = self.model_ft(data)
         features = features.unsqueeze(0)
         # print(features.shape)
         # features = features.squeeze()
@@ -177,21 +171,28 @@ class ModelInterface(pl.LightningModule):
 
         #---->loss
         loss = self.loss(logits, label)
+        # loss = self.asl(logits, label.squeeze())
 
         #---->acc log
         # print(label)
-        Y_hat = int(Y_hat)
         # if self.n_classes == 2:
         #     Y = int(label[0][1])
         # else: 
         Y = torch.argmax(label)
             # Y = int(label[0])
         self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (Y_hat == Y)
+        self.data[Y]["correct"] += (int(Y_hat) == Y)
 
-        return {'loss': loss} 
+        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
 
     def training_epoch_end(self, training_step_outputs):
+        # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
+        max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs])
+        # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
+        target = torch.cat([x['label'] for x in training_step_outputs])
+        target = torch.argmax(target, dim=1)
         for c in range(self.n_classes):
             count = self.data[c]["count"]
             correct = self.data[c]["correct"]
@@ -202,12 +203,19 @@ class ModelInterface(pl.LightningModule):
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
+        if self.current_epoch % 10 == 0:
+            self.log_confusion_matrix(probs, target, stage='train')
+
+        self.log('Train/auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+
     def validation_step(self, batch, batch_idx):
 
         data, label, _ = batch
 
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
         features = self.model_ft(data)
         features = features.unsqueeze(0)
 
@@ -224,20 +232,23 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
 
 
     def validation_epoch_end(self, val_step_outputs):
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
-        probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
+        # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
+        # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
+        target = torch.cat([x['label'] for x in val_step_outputs])
+        target = torch.argmax(target, dim=1)
         #---->
         # logits = logits.long()
         # target = target.squeeze().long()
         # logits = logits.squeeze(0)
         self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
-        self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+        self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
 
         # print(max_probs.squeeze(0).shape)
         # print(target.shape)
@@ -245,12 +256,8 @@ class ModelInterface(pl.LightningModule):
                           on_epoch = True, logger = True)
 
         #----> log confusion matrix
-        confmat = self.confusion_matrix(max_probs.squeeze(), target)
-        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        plt.figure()
-        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        plt.close(fig_)
-        self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
+        self.log_confusion_matrix(probs, target, stage='val')
+        
 
         #---->acc log
         for c in range(self.n_classes):
@@ -267,18 +274,12 @@ class ModelInterface(pl.LightningModule):
         if self.shuffle == True:
             self.count = self.count+1
             random.seed(self.count*50)
-    
-
-
-    def configure_optimizers(self):
-        optimizer = create_optimizer(self.optimizer, self.model)
-        return [optimizer]
 
     def test_step(self, batch, batch_idx):
 
         data, label, _ = batch
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
         features = self.model_ft(data)
         features = features.unsqueeze(0)
 
@@ -292,12 +293,14 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
 
     def test_epoch_end(self, output_results):
-        probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in output_results])
         max_probs = torch.stack([x['Y_hat'] for x in output_results])
-        target = torch.stack([x['label'] for x in output_results], dim = 0)
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.cat([x['label'] for x in output_results])
+        target = torch.argmax(target, dim=1)
         
         #---->
         auc = self.AUROC(probs, target.squeeze())
@@ -326,19 +329,16 @@ class ModelInterface(pl.LightningModule):
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
-        confmat = self.confusion_matrix(max_probs.squeeze(), target)
-        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        plt.figure()
-        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        # plt.close(fig_)
-        # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
-        plt.savefig(f'{self.log_path}/cm_test')
-        plt.close(fig_)
-
+        self.log_confusion_matrix(probs, target, stage='test')
         #---->
         result = pd.DataFrame([metrics])
         result.to_csv(self.log_path / 'result.csv')
 
+    def configure_optimizers(self):
+        # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
+        optimizer = create_optimizer(self.optimizer, self.model)
+        return optimizer     
+
 
     def load_model(self):
         name = self.hparams.model.name
@@ -350,6 +350,7 @@ class ModelInterface(pl.LightningModule):
         else:
             camel_name = name
         try:
             Model = getattr(importlib.import_module(
                 f'models.{name}'), camel_name)
         except:
@@ -371,6 +372,20 @@ class ModelInterface(pl.LightningModule):
         args1.update(other_args)
         return Model(**args1)
 
+    def log_confusion_matrix(self, preds, target, stage):
+        confmat = self.confusion_matrix(preds.squeeze(), target)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        plt.figure()
+        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+        self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+        if stage == 'test':
+            plt.savefig(f'{self.log_path}/cm_test')
+        plt.close(fig_)
+
 class View(nn.Module):
     def __init__(self, shape):
         super().__init__()
@@ -383,4 +398,5 @@ class View(nn.Module):
         # batch_size = input.size(0)
         # shape = (batch_size, *self.shape)
         out = input.view(*self.shape)
-        return out
\ No newline at end of file
+        return out
+
diff --git a/models/resnet50.py b/models/resnet50.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e23d7460df4b37d9455d9ab7a07d4350fd802f
--- /dev/null
+++ b/models/resnet50.py
@@ -0,0 +1,293 @@
+import torch
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
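+
+
+if __name__ == '__main__':
+    # Minimal smoke test (not part of the original file): a random batch through
+    # an untrained ResNet-50 should yield [2, 1000] logits.
+    model = resnet50(pretrained=False)
+    out = model(torch.randn(2, 3, 224, 224))
+    print(out.shape)  # expected: torch.Size([2, 1000])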
diff --git a/train.py b/train.py
index 182e3c0f4f0df1096166441731e39794f8599766..036d5ed183d4c8ee047ff3413bfb3d4730154b19 100644
--- a/train.py
+++ b/train.py
@@ -3,6 +3,8 @@ from pathlib import Path
 import numpy as np
 import glob
 
+from sklearn.model_selection import KFold
+
 from datasets.data_interface import DataInterface, MILDataModule
 from models.model_interface import ModelInterface
 import models.vision_transformer as vits
@@ -11,25 +13,32 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
+from utils.utils import KFoldLoop
+import torch
 
 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
-    # parser.add_argument('--gpus', default = [2])
+    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
+    parser.add_argument('--bag_size', default = 1024, type=int)
     args = parser.parse_args()
     return args
 
 #---->main
 def main(cfg):
 
+    torch.set_num_threads(16)
+
     #---->Initialize seed
     pl.seed_everything(cfg.General.seed)
 
     #---->load loggers
     cfg.load_loggers = load_loggers(cfg)
+    # print(cfg.load_loggers)
 
     #---->load callbacks
     cfg.callbacks = load_callbacks(cfg)
@@ -49,7 +58,10 @@ def main(cfg):
                 'batch_size': cfg.Data.train_dataloader.batch_size,
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
+                'backbone': cfg.Model.backbone,
+                'bag_size': cfg.Data.bag_size,
                 }
+
     dm = MILDataModule(**DataInterface_dict)
     
 
@@ -70,9 +82,9 @@ def main(cfg):
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        gpus=cfg.General.gpus,
-        # gpus = [4],
-        # strategy='ddp',
+        gpus=cfg.General.gpus,
+        strategy='ddp',
         amp_backend='native',
         # amp_level=cfg.General.amp_level,  
         precision=cfg.General.precision,  
@@ -85,6 +97,7 @@ def main(cfg):
 
     #---->train or test
     if cfg.General.server == 'train':
+        # wrap the default FitLoop with the KFoldLoop defined in utils/utils.py;
+        # per-fold checkpoints are exported to cfg.log_path
+        internal_fit_loop = trainer.fit_loop
+        trainer.fit_loop = KFoldLoop(3, export_path=str(cfg.log_path))
+        trainer.fit_loop.connect(internal_fit_loop)
         trainer.fit(model = model, datamodule = dm)
     else:
         model_paths = list(cfg.log_path.glob('*.ckpt'))
@@ -94,6 +107,7 @@ def main(cfg):
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
             trainer.test(model=new_model, datamodule=dm)
 
+
 if __name__ == '__main__':
 
     args = make_parse()
@@ -101,9 +115,12 @@ if __name__ == '__main__':
 
     #---->update
     cfg.config = args.config
-    # cfg.General.gpus = args.gpus
+    cfg.General.gpus = [args.gpus]
     cfg.General.server = args.stage
     cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
+    
 
     #---->main
     main(cfg)
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
index e1eeb958497b23b1f9e9934764083531720a66d0..704aade8ba3e8bcbedf237b3fde5db80b84fcd32 100644
Binary files a/utils/__pycache__/utils.cpython-39.pyc and b/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/utils/extract_features.py b/utils/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb040a175099d6e6612a7634a10eee07c4345cde
--- /dev/null
+++ b/utils/extract_features.py
@@ -0,0 +1,27 @@
+## Choose Model and extract features from (augmented) image patches and save as .pt file
+
+import torch
+
+from datasets.custom_dataloader import HDF5MILDataloader
+from models.resnet50 import resnet50
+
+
+def extract_features(data_root, label_path, output_dir, model_name='resnet50', n_classes=2, batch_size=1):
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    if model_name == 'resnet50':
+        # resnet50 from models/resnet50.py stands in for the undefined Resnet50_baseline
+        model = resnet50(pretrained=True)
+        model = model.to(device)
+        model.eval()
+
+
+
+if __name__ == '__main__':
+
+    # input_dir, output_dir
+    # initiate data loader
+    # use data loader to load and augment images 
+    # prediction from model
+    # choose save as bag or not (needed?)
+    
+    # features = torch.from_numpy(features)
+    # torch.save(features, output_path + '.pt') 
+
diff --git a/utils/utils.py b/utils/utils.py
index 96ed223d1f73fb9afbf951486376bc02c76ae75a..010eaab0c51e6fc784774502d85e7eef7303487e 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,4 +1,19 @@
 from pathlib import Path
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from os import path as osp
+from typing import Any, Dict, List, Type
+import torch
+import torchvision.transforms as T
+from sklearn.model_selection import KFold
+from torch.nn import functional as F
+from torch.utils.data import random_split
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Dataset, Subset
+from torchmetrics.classification.accuracy import Accuracy
+
+from pytorch_lightning import LightningDataModule, seed_everything, Trainer
+from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
 
 #---->read yaml
 import yaml
@@ -14,19 +29,32 @@ def load_loggers(cfg):
 
     log_path = cfg.General.log_path
     Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}'
+    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
     version_name = Path(cfg.config).name[:-5]
-    cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
-    print(f'---->Log dir: {cfg.log_path}')
     
     #---->TensorBoard
-    tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                             name = version_name, version = f'fold{cfg.Data.fold}',
-                                             log_graph = True, default_hp_metric = False)
-    #---->CSV
-    csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                      name = version_name, version = f'fold{cfg.Data.fold}', )
+    if cfg.General.server != 'test':
+        cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
+        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
+                                                name = version_name, version = f'fold{cfg.Data.fold}',
+                                                log_graph = True, default_hp_metric = False)
+        #---->CSV
+        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
+                                        name = version_name, version = f'fold{cfg.Data.fold}', )
+    else:  
+        cfg.log_path = Path(log_path) / log_name / version_name / f'test'
+        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
+                                                name = version_name, version = f'test',
+                                                log_graph = True, default_hp_metric = False)
+        #---->CSV
+        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
+                                        name = version_name, version = f'test', )
     
+    print(f'---->Log dir: {cfg.log_path}')
+
+    # return tb_logger
     return [tb_logger, csv_logger]
 
 
@@ -74,6 +102,14 @@ def load_callbacks(cfg):
                                          save_top_k = 1,
                                          mode = 'min',
                                          save_weights_only = True))
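+        # additionally keep the checkpoint with the best (highest) validation AUC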
+        Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
+                                         dirpath = str(cfg.log_path),
+                                         filename = '{epoch:02d}-{val_auc:.4f}',
+                                         verbose = True,
+                                         save_last = True,
+                                         save_top_k = 1,
+                                         mode = 'max',
+                                         save_weights_only = True))
     return Mycallbacks
 
 #---->val loss
@@ -84,3 +120,103 @@ def cross_entropy_torch(x, y):
     x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
     loss = - torch.sum(x_log) / y.shape[0]
     return loss
+
+#-----> convert labels for task
+label_map = {
+    'bin': {'0': 0, '1': 1, '2': 1, '3': 1, '4': 1, '5': None},
+    'tcmr_viral': {'0': None, '1': 0, '2': None, '3': None, '4': 1, '5': None},
+    'no_viral': {'0': 0, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'no_other': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': None},
+    'no_stable': {'0': None, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'all': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5},
+}
+def convert_labels_for_task(task, label):
+
+    return label_map[task][label]
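+# e.g. convert_labels_for_task('tcmr_viral', '1') -> 0; a None mapping marks a class
+# that is excluded from the given task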
+
+
+#-----> KFOLD LOOP
+
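+# The two classes below are referenced by KFoldLoop but were missing; they are adapted
+# from the official PyTorch Lightning loop-customization (KFold) example. Assumption:
+# EnsembleVotingModel.test_step treats the wrapped model as a callable mapping a batch
+# tensor to logits, which may need adapting to ModelInterface's dict-based outputs.
+class BaseKFoldDataModule(LightningDataModule, ABC):
+    @abstractmethod
+    def setup_folds(self, num_folds: int) -> None:
+        """Split the underlying dataset into `num_folds` folds."""
+
+    @abstractmethod
+    def setup_fold_index(self, fold_index: int) -> None:
+        """Select the train/val subsets belonging to `fold_index`."""
+
+
+class EnsembleVotingModel(LightningModule):
+    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]) -> None:
+        super().__init__()
+        # one model per fold, restored from the checkpoints saved by KFoldLoop
+        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
+        self.test_acc = Accuracy()
+
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        # soft voting: average the per-fold logits before scoring
+        logits = torch.stack([m(batch[0]) for m in self.models]).mean(0)
+        self.test_acc(F.softmax(logits, dim=-1), batch[1])
+        self.log('test_acc', self.test_acc)
+
+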
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, export_path: str) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to the run a fitting and testing on the current hold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
+        voting_model.trainer = self.trainer
+        # This requires to connect the new model and move it the right device.
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # requires to be overridden as attributes of the wrapped loop are being accessed.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file