diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..4cf8dd15619e7c11d325ae0eb80bba874a99f06d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+logs/*
\ No newline at end of file
diff --git a/Camelyon/TransMIL.yaml b/Camelyon/TransMIL.yaml
index 6642b702777078a7c84261101baa1d3844445847..fba514cc00e181fe48a698215a5de1b68034f6c2 100644
--- a/Camelyon/TransMIL.yaml
+++ b/Camelyon/TransMIL.yaml
@@ -29,9 +29,13 @@ Data:
         batch_size: 1
         num_workers: 8
 
+        
+
+
 Model:
     name: TransMIL
     n_classes: 2
+    backbone: resnet18
 
 
 Optimizer:
diff --git a/DeepGraft/TransMIL.yaml b/DeepGraft/TransMIL.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7945828389d87697cff095a2fcc4c5a062437f90
--- /dev/null
+++ b/DeepGraft/TransMIL.yaml
@@ -0,0 +1,50 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 200 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json'
+    fold: 0
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: TransMIL
+    n_classes: 6
+    backbone: resnet18
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
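
This config mirrors the Camelyon layout: `General` holds trainer flags, `Data` the HDF5 root and fold split, `Model` now carries a `backbone` key next to `name` and `n_classes`, followed by `Optimizer` and `Loss`. A minimal sketch of reading such a file, assuming only PyYAML (`safe_load` resolves the `&epoch` anchor):

```python
import yaml

# Parse the config into a plain nested dict; a hypothetical stand-in for
# however the training script actually consumes these YAML files.
with open('DeepGraft/TransMIL.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['Model']['name'], cfg['Model']['n_classes'], cfg['Model']['backbone'])
# -> TransMIL 6 resnet18
print(cfg['Optimizer']['opt'], cfg['Optimizer']['lr'])
# -> lookahead_radam 0.0002
```
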
diff --git a/DeepGraft/TransMIL_dino.yaml b/DeepGraft/TransMIL_dino.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ffe987c913ff9075730d7b16b8e9dba16d5d4978
--- /dev/null
+++ b/DeepGraft/TransMIL_dino.yaml
@@ -0,0 +1,50 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json'
+    fold: 0
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: dino
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet18_all.yaml b/DeepGraft/TransMIL_resnet18_all.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8fa5818981b31c1a6c36f34b42e055f2198681fc
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet18_all.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_all.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 6
+    backbone: resnet18
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet18_no_other.yaml b/DeepGraft/TransMIL_resnet18_no_other.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..95a9bd64692f5ee12dd822e04fea889b80717457
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet18_no_other.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [4]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 5
+    backbone: resnet18
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet18_no_viral.yaml b/DeepGraft/TransMIL_resnet18_no_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..155b676a24e0541f42f8fee12af2d987217c0525
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet18_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 4
+    backbone: resnet18
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e7d9bf0694a227f987d1d2fbf2f0facb53c248d5
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet18
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50.yaml b/DeepGraft/TransMIL_resnet50.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d76b2cf618dc0f3f249f9aac787eb471793bd49f
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50.yaml
@@ -0,0 +1,50 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 32 
+    multi_gpu_mode: dp
+    gpus: [1]
+    epochs: &epoch 200 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json'
+    fold: 0
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50_all.yaml b/DeepGraft/TransMIL_resnet50_all.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eba3a4fa20870ad4fc2b173ccb4e60086ddb3ac5
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_all.yaml
@@ -0,0 +1,50 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [1]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: train # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_all.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: TransMIL
+    n_classes: 6
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_simple.yaml b/DeepGraft/TransMIL_simple.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4501f2c06242bb6cb951fef458eab000daac0d75
--- /dev/null
+++ b/DeepGraft/TransMIL_simple.yaml
@@ -0,0 +1,50 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [1]
+    epochs: &epoch 200 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test # train or test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json'
+    fold: 0
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc b/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b39de626d6dab873f33584529f0c3a4bef6bf961
Binary files /dev/null and b/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/__init__.cpython-39.pyc b/MyLoss/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..119cc4d4a02cb0c100c171e5e23385ccb8309724
Binary files /dev/null and b/MyLoss/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/boundary_loss.cpython-39.pyc b/MyLoss/__pycache__/boundary_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34040291007bdff88adfcea8434647cf1d47f3a2
Binary files /dev/null and b/MyLoss/__pycache__/boundary_loss.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/dice_loss.cpython-39.pyc b/MyLoss/__pycache__/dice_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23837a13153c252543d3e48776b1fbdea5eff0a2
Binary files /dev/null and b/MyLoss/__pycache__/dice_loss.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/focal_loss.cpython-39.pyc b/MyLoss/__pycache__/focal_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b2734f4edcae0885494fb2002e9678f127ccc8d
Binary files /dev/null and b/MyLoss/__pycache__/focal_loss.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/hausdorff.cpython-39.pyc b/MyLoss/__pycache__/hausdorff.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4472c5317ce1e97640f7705acdd550d6ff75b161
Binary files /dev/null and b/MyLoss/__pycache__/hausdorff.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed7437016bbcadc48f9bb0a97401529dc574c9b2
Binary files /dev/null and b/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc b/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b50bae1fae55772838d508428f39b9eb9d4ca90b
Binary files /dev/null and b/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc differ
diff --git a/MyLoss/loss_factory.py b/MyLoss/loss_factory.py
index 2394abe78706f535cf7e20da871031386b72005a..1dffa6182ef40b1f46436b8a1eadb0a8906c17ff 100755
--- a/MyLoss/loss_factory.py
+++ b/MyLoss/loss_factory.py
@@ -34,8 +34,6 @@ def create_loss(args, w1=1.0, w2=0.5):
         loss = L.BinaryDiceLoss()
     elif conf_loss == "dice_log":
         loss = L.BinaryDiceLogLoss()
-    elif conf_loss == "dice_log":
-        loss = L.BinaryDiceLogLoss()
     elif conf_loss == "bce+lovasz":
         loss = L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss(), w1, w2)
     elif conf_loss == "lovasz":
@@ -62,6 +60,7 @@ def make_parse():
 if __name__ == '__main__':
     args = make_parse()
     myloss = create_loss(args)
+    print(myloss)
     data = torch.randn(2, 3)
     label = torch.empty(2, dtype=torch.long).random_(3)
     loss = myloss(data, label)
\ No newline at end of file
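
The hunk above drops a duplicated `dice_log` branch and prints the constructed loss in the smoke test. The factory itself is a plain name-to-module dispatch; a condensed sketch of the pattern with just two branches (not the repository's full list):

```python
import torch
import torch.nn as nn

def create_loss_sketch(name: str) -> nn.Module:
    # Simplified stand-in for MyLoss.loss_factory.create_loss.
    if name == 'CrossEntropyLoss':
        return nn.CrossEntropyLoss()
    elif name == 'bce':
        return nn.BCEWithLogitsLoss()
    raise ValueError(f'unknown loss: {name}')

myloss = create_loss_sketch('CrossEntropyLoss')
data = torch.randn(2, 3)                              # [batch, n_classes]
label = torch.empty(2, dtype=torch.long).random_(3)   # class indices
print(myloss, myloss(data, label))
```
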
diff --git a/MyOptimizer/__pycache__/__init__.cpython-39.pyc b/MyOptimizer/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..058b9be16d0a0efced0d6e4650eb2c5eb3a2450b
Binary files /dev/null and b/MyOptimizer/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/adafactor.cpython-39.pyc b/MyOptimizer/__pycache__/adafactor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5594d367f0f462a2976347a368a45da43954538c
Binary files /dev/null and b/MyOptimizer/__pycache__/adafactor.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/adahessian.cpython-39.pyc b/MyOptimizer/__pycache__/adahessian.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc22c127ae8d56a626d590c23499817208854192
Binary files /dev/null and b/MyOptimizer/__pycache__/adahessian.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/adamp.cpython-39.pyc b/MyOptimizer/__pycache__/adamp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ba6bf0c26ff7652d56f9139e6b93a9ef69f58aa
Binary files /dev/null and b/MyOptimizer/__pycache__/adamp.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/adamw.cpython-39.pyc b/MyOptimizer/__pycache__/adamw.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cae0384996e05f7a7498980d9d36881eb3625d2e
Binary files /dev/null and b/MyOptimizer/__pycache__/adamw.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/lookahead.cpython-39.pyc b/MyOptimizer/__pycache__/lookahead.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ae677bc3ab72ad3f7a4591ee097c38fb2755106
Binary files /dev/null and b/MyOptimizer/__pycache__/lookahead.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/nadam.cpython-39.pyc b/MyOptimizer/__pycache__/nadam.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bcbfa6a76f3d6e00685eef8c40c0840e26ae884
Binary files /dev/null and b/MyOptimizer/__pycache__/nadam.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/novograd.cpython-39.pyc b/MyOptimizer/__pycache__/novograd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9f49e4a9cd08791de1aafb9c381117c17a9d9ee
Binary files /dev/null and b/MyOptimizer/__pycache__/novograd.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc b/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3528945920d301895ec14ba380f4b1132b62478e
Binary files /dev/null and b/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc b/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6bfac29b2f804c5be7a3f18b97fe6d9401a3cdc
Binary files /dev/null and b/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/radam.cpython-39.pyc b/MyOptimizer/__pycache__/radam.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..11a200d1a53e2fa7f59e54daaae3f30c6043ce52
Binary files /dev/null and b/MyOptimizer/__pycache__/radam.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc b/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70fc2eb4e1c4a2ff4658ffaa59558f13219a5a73
Binary files /dev/null and b/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc differ
diff --git a/MyOptimizer/__pycache__/sgdp.cpython-39.pyc b/MyOptimizer/__pycache__/sgdp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b747bafd446e4ab8cc65806fbf09cb1f38430e4
Binary files /dev/null and b/MyOptimizer/__pycache__/sgdp.cpython-39.pyc differ
diff --git a/MyOptimizer/lookahead.py b/MyOptimizer/lookahead.py
index 6b5b7f38ec8cb6594e3986b66223fa2881daeca3..b8e8b0095e8da7a28b8742df9e0322996941bb74 100755
--- a/MyOptimizer/lookahead.py
+++ b/MyOptimizer/lookahead.py
@@ -35,7 +35,9 @@ class Lookahead(Optimizer):
                 param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                 param_state['slow_buffer'].copy_(fast_p.data)
             slow = param_state['slow_buffer']
-            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
+            # deprecated positional form: slow.add_(group['lookahead_alpha'], fast_p.data - slow)
+            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
+
             fast_p.data.copy_(slow)
 
     def sync_lookahead(self):
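
The Lookahead slow-weight update is `slow += alpha * (fast - slow)`; the patch only migrates from the deprecated positional `Tensor.add_(scalar, tensor)` signature to the keyword `alpha=` form. A quick check that the new call computes the same interpolation:

```python
import torch

alpha = 0.5                        # group['lookahead_alpha']
slow = torch.tensor([0.0, 0.0])    # slow (lookahead) weights
fast = torch.tensor([1.0, 2.0])    # fast (inner optimizer) weights

slow.add_(fast - slow, alpha=alpha)  # slow += alpha * (fast - slow)
print(slow)                          # tensor([0.5000, 1.0000])
```
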
diff --git a/MyOptimizer/optim_factory.py b/MyOptimizer/optim_factory.py
index ce310e3f593680b579369b51d047a681b41ce351..992231aab94e896725de99e636595e3b0ce2ebe7 100755
--- a/MyOptimizer/optim_factory.py
+++ b/MyOptimizer/optim_factory.py
@@ -75,7 +75,8 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     elif opt_lower == 'nadam':
         optimizer = Nadam(parameters, **opt_args)
     elif opt_lower == 'radam':
-        optimizer = RAdam(parameters, **opt_args)
+        # optimizer = RAdam(parameters, **opt_args)  # bundled implementation
+        optimizer = optim.RAdam(parameters, **opt_args)
     elif opt_lower == 'adamp':        
         optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
     elif opt_lower == 'sgdp':        
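
Here the bundled `RAdam` gives way to `torch.optim.RAdam` (part of PyTorch since 1.10). A hedged sketch of what the `lookahead_radam` setting in the configs would yield, assuming the factory wraps the base optimizer in the repository's `Lookahead` whenever the name carries a `lookahead_` prefix, as timm-style factories do:

```python
import torch.nn as nn
from torch import optim
from MyOptimizer.lookahead import Lookahead  # repository wrapper

model = nn.Linear(512, 2)
base = optim.RAdam(model.parameters(), lr=2e-4, weight_decay=1e-5)
optimizer = Lookahead(base)  # lookahead_radam = Lookahead around RAdam
```
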
diff --git a/datasets/__pycache__/__init__.cpython-39.pyc b/datasets/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13006df55083dc070f46953e4505bfef1ac8b198
Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/datasets/__pycache__/camel_data.cpython-39.pyc b/datasets/__pycache__/camel_data.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ffc96a2ce0ccb33cabde901455fd1b7c9c44811
Binary files /dev/null and b/datasets/__pycache__/camel_data.cpython-39.pyc differ
diff --git a/datasets/__pycache__/camel_dataloader.cpython-39.pyc b/datasets/__pycache__/camel_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..11da2c2e9cc3f2bb97782008855d83b10df237c4
Binary files /dev/null and b/datasets/__pycache__/camel_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..147b3fc3c4628d6eed5bebd6ff248d31aaeebb34
Binary files /dev/null and b/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4af0141496c11d6f97255add8746fa0067e55636
Binary files /dev/null and b/datasets/__pycache__/data_interface.cpython-39.pyc differ
diff --git a/datasets/camel_dataloader.py b/datasets/camel_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..302cabd491bf81619af8c11692ae560ea81bf410
--- /dev/null
+++ b/datasets/camel_dataloader.py
@@ -0,0 +1,125 @@
+import pandas as pd
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.functional import one_hot
+import torch.utils.data as data_utils
+from torchvision import datasets, transforms
+from sklearn.utils import shuffle
+from pathlib import Path
+from tqdm import tqdm
+
+
+class FeatureBagLoader(data_utils.Dataset):
+    def __init__(self, data_root, train=True, cache=True):
+
+        bags_path = pd.read_csv(data_root)
+
+        self.train_path = bags_path.iloc[0:int(len(bags_path)*0.8), :]
+        self.test_path = bags_path.iloc[int(len(bags_path)*0.8):, :]
+        # self.train_path = shuffle(train_path).reset_index(drop=True)
+        # self.test_path = shuffle(test_path).reset_index(drop=True)
+
+        home = Path.cwd().parts[1]
+        self.origin_path = Path(f'/{home}/ylan/RCC_project/rcc_classification/')
+        # self.target_number = target_number
+        # self.mean_bag_length = mean_bag_length
+        # self.var_bag_length = var_bag_length
+        # self.num_bag = num_bag
+        self.cache = cache
+        self.train = train
+        self.n_classes = 2
+
+        self.features = []
+        self.labels = []
+        if self.cache:
+            if train:
+                with tqdm(total=len(self.train_path)) as pbar:
+                    for t in tqdm(self.train_path.iloc()):
+                        ft, lbl = self.get_bag_feats(t)
+                        # ft = ft.view(-1, 512)
+                        
+                        self.labels.append(lbl)
+                        self.features.append(ft)
+                        pbar.update()
+            else: 
+                with tqdm(total=len(self.test_path)) as pbar:
+                    for t in tqdm(self.test_path.iloc()):
+                        ft, lbl = self.get_bag_feats(t)
+                        # lbl = Variable(Tensor(lbl))
+                        # ft = Variable(Tensor(ft)).view(-1, 512)
+                        self.labels.append(lbl)
+                        self.features.append(ft)
+                        pbar.update()
+        # print(self.get_bag_feats(self.train_path))
+        # self.r = np.random.RandomState(seed)
+
+        # self.num_in_train = 60000
+        # self.num_in_test = 10000
+
+        # if self.train:
+        #     self.train_bags_list, self.train_labels_list = self._create_bags()
+        # else:
+        #     self.test_bags_list, self.test_labels_list = self._create_bags()
+
+    def get_bag_feats(self, csv_file_df):
+        # if args.dataset == 'TCGA-lung-default':
+        #     feats_csv_path = 'datasets/tcga-dataset/tcga_lung_data_feats/' + csv_file_df.iloc[0].split('/')[1] + '.csv'
+        # else:
+        
+        feats_csv_path = self.origin_path / csv_file_df.iloc[0]
+        df = pd.read_csv(feats_csv_path)
+        # feats = shuffle(df).reset_index(drop=True)
+        # feats = feats.to_numpy()
+        feats = df.to_numpy()
+        label = np.zeros(self.n_classes)
+        if self.n_classes==2:
+            label[1] = csv_file_df.iloc[1]
+        else:
+            if int(csv_file_df.iloc[1])<=(len(label)-1):
+                label[int(csv_file_df.iloc[1])] = 1
+        
+        return feats, label
+
+    def __len__(self):
+        if self.train:
+            return len(self.train_path)
+        else:
+            return len(self.test_path)
+
+    def __getitem__(self, index):
+
+        if self.cache:
+            label = self.labels[index]
+            feats = self.features[index]
+            label = Variable(Tensor(label))
+            feats = Variable(Tensor(feats)).view(-1, 512)
+            return feats, label
+        else:
+            if self.train:
+                feats, label = self.get_bag_feats(self.train_path.iloc[index])
+                label = Variable(Tensor(label))
+                feats = Variable(Tensor(feats)).view(-1, 512)
+            else:
+                feats, label = self.get_bag_feats(self.test_path.iloc[index])
+                label = Variable(Tensor(label))
+                feats = Variable(Tensor(feats)).view(-1, 512)
+
+            return feats, label
+
+if __name__ == '__main__':
+    import os
+    cwd = os.getcwd()
+    home = cwd.split('/')[1]
+    data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv'
+    dataset = FeatureBagLoader(data_root, cache=False)
+    for i in dataset: 
+        # print(i[1])
+        # print(i)
+        
+        features, label = i
+        print(label)
+        # print(features.shape)
+        # print(label[0].long())
\ No newline at end of file
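
`get_bag_feats` returns a per-bag feature matrix plus a label vector: in the binary case the target lands in `label[1]`, otherwise the class index is one-hot encoded (and silently skipped if out of range). A standalone restatement of that label logic:

```python
import numpy as np

def encode_label(y: int, n_classes: int) -> np.ndarray:
    # Mirrors the label handling in FeatureBagLoader.get_bag_feats.
    label = np.zeros(n_classes)
    if n_classes == 2:
        label[1] = y               # binary: target stored in slot 1
    elif int(y) <= n_classes - 1:
        label[int(y)] = 1          # multi-class: one-hot
    return label

print(encode_label(1, 2))   # [0. 1.]
print(encode_label(3, 6))   # [0. 0. 0. 1. 0. 0.]
```
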
diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddc2ed36c64fc46f0844673514b51cead0052bc6
--- /dev/null
+++ b/datasets/custom_dataloader.py
@@ -0,0 +1,332 @@
+import h5py
+# import helpers
+import numpy as np
+from pathlib import Path
+import torch
+# from torch._C import long
+from torch.utils import data
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+# from histoTransforms import RandomHueSaturationValue
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+import csv
+from PIL import Image
+import cv2
+import pandas as pd
+import json
+
+class HDF5MILDataloader(data.Dataset):
+    """Represents an abstract HDF5 dataset. For single H5 container! 
+    
+    Input params:
+        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
+        mode: 'train' or 'test'
+        load_data: If True, loads all the data immediately into RAM. Use this if
+            the dataset is fits into memory. Otherwise, leave this at false and 
+            the data will load lazily.
+        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+
+    """
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.bag_size = 120
+        # self.label_file = label_path
+        recursive = True
+
+        # read labels and slide_path from csv
+
+        # df = pd.read_csv(self.csv_path)
+        # labels = df.LABEL
+        # slides = df.FILENAME
+        with open(self.label_path, 'r') as f:
+            self.slideLabelDict = json.load(f)[mode]
+
+        self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict}
+
+            
+        # if Path(slides[0]).suffix:
+        #     slides = list(map(lambda x: Path(x).stem, slides))
+
+        # print(labels)
+        # print(slides)
+        # self.slideLabelDict = dict(zip(slides, labels))
+        # print(self.slideLabelDict)
+
+        #check if files in slideLabelDict, only take files that are available.
+
+        files_in_path = list(Path(self.file_path).rglob('*.hdf5'))
+        files_in_path = [x.stem for x in files_in_path]
+        # print(len(files_in_path))
+        # print(files_in_path)
+        # print(list(self.slideLabelDict.keys()))
+        # for x in list(self.slideLabelDict.keys()):
+        #     if x in files_in_path:
+        #         path = Path(self.file_path) / (x + '.hdf5')
+        #         print(path)
+
+        self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path]
+
+        print(len(self.files))
+        # self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys())))
+
+        for h5dataset_fp in tqdm(self.files):
+            # print(h5dataset_fp)
+            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
+
+        # print(self.data_info)
+        self.resize_transforms = transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize(256),
+        ])
+
+        self.img_transforms = transforms.Compose([    
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandomVerticalFlip(p=1),
+            # histoTransforms.AutoRandomRotation(),
+            transforms.Lambda(lambda a: np.array(a)),
+        ]) 
+        self.hsv_transforms = transforms.Compose([
+            RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
+            transforms.ToTensor()
+        ])
+
+        # self._add_data_infos(load_data)
+
+
+    def __getitem__(self, index):
+        # get data
+        batch, label, name = self.get_data(index)
+        out_batch = []
+        
+        if self.mode == 'train':
+            # print(img)
+            # print(img.shape)
+            for img in batch: 
+                img = self.img_transforms(img)
+                img = self.hsv_transforms(img)
+                out_batch.append(img)
+
+        else:
+            for img in batch:
+                img = transforms.functional.to_tensor(img)
+                out_batch.append(img)
+        if len(out_batch) == 0:
+            # print(name)
+            out_batch = torch.randn(100,3,256,256)
+        else: out_batch = torch.stack(out_batch)
+        out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+        return out_batch, label, name
+
+    def __len__(self):
+        return len(self.data_info)
+    
+    def _add_data_infos(self, file_path, load_data):
+        wsi_name = Path(file_path).stem
+        if wsi_name in self.slideLabelDict:
+            label = self.slideLabelDict[wsi_name]
+            wsi_batch = []
+            # with h5py.File(file_path, 'r') as h5_file:
+            #     numKeys = len(h5_file.keys())
+            #     sample = list(h5_file.keys())[0]
+            #     shape = (numKeys,) + h5_file[sample][:].shape
+                # for tile in h5_file.keys():
+                #     img = h5_file[tile][:]
+                    
+                    # print(img)
+                    # if type == 'images':
+                    #     t = 'data'
+                    # else: 
+                    #     t = 'label'
+            idx = -1
+            # if load_data: 
+            #     for tile in h5_file.keys():
+            #         img = h5_file[tile][:]
+            #         img = img.astype(np.uint8)
+            #         img = self.resize_transforms(img)
+            #         wsi_batch.append(img)
+            #     idx = self._add_to_cache(wsi_batch, file_path)
+                #     wsi_batch.append(img)
+            # self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx})
+            self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        with h5py.File(file_path, 'r') as h5_file:
+            wsi_batch = []
+            for tile in h5_file.keys():
+                img = h5_file[tile][:]
+                img = img.astype(np.uint8)
+                img = self.resize_transforms(img)
+                wsi_batch.append(img)
+            idx = self._add_to_cache(wsi_batch, file_path)
+            file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+            self.data_info[file_idx + idx]['cache_idx'] = idx
+
+            # for type in ['images', 'labels']:
+            #     for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()):
+            #         img = h5_file[data_path][:]
+            #         idx = self._add_to_cache(img, data_path)
+            #         file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path)
+            #         self.data_info[file_idx + idx]['cache_idx'] = idx
+            # for gname, group in h5_file.items():
+            #     for dname, ds in group.items():
+            #         # add data to the data cache and retrieve
+            #         # the cache index
+            #         idx = self._add_to_cache(ds.value, file_path)
+
+            #         # find the beginning index of the hdf5 file we are looking for
+            #         file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)
+
+            #         # the data info should have the same index since we loaded it in the same way
+            #         self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # remove invalid cache_idx
+            # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+        """
+        if data_path not in self.data_cache:
+            self.data_cache[data_path] = [data]
+        else:
+            self.data_cache[data_path].append(data)
+        return len(self.data_cache[data_path]) - 1
+
+    # def get_data_infos(self, type):
+    #     """Get data infos belonging to a certain type of data.
+    #     """
+    #     data_info_type = [di for di in self.data_info if di['type'] == type]
+    #     return data_info_type
+
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+            dataset. This will make sure that the data is loaded in case it is
+            not part of the data cache.
+            i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+        
+        # get new cache_idx assigned by _load_data_info
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name
+
+
+class RandomHueSaturationValue(object):
+
+    def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+        
+        self.hue_shift_limit = hue_shift_limit
+        self.sat_shift_limit = sat_shift_limit
+        self.val_shift_limit = val_shift_limit
+        self.p = p
+
+    def __call__(self, sample):
+    
+        img = sample #,lbl
+    
+        if np.random.random() < self.p:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+            h, s, v = cv2.split(img)
+            hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+            hue_shift = np.uint8(hue_shift)
+            h += hue_shift
+            sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+            s = cv2.add(s, sat_shift)
+            val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+            v = cv2.add(v, val_shift)
+            img = cv2.merge((h, s, v))
+            img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+        return img #, lbl
+
+
+
+if __name__ == '__main__':
+    from pathlib import Path
+    import os
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json'
+    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/'
+    # os.makedirs(output_path, exist_ok=True)
+
+
+    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6)
+    data = DataLoader(dataset, batch_size=1)
+
+    # print(len(dataset))
+    x = 0
+    c = 0
+    for item in data: 
+        if c >=10:
+            break
+        bag, label, name = item
+        print(bag)
+        # # print(bag.shape)
+        # if bag.shape[1] == 1:
+        #     print(name)
+        #     print(bag.shape)
+        # print(bag.shape)
+        # print(name)
+        # out_dir = Path(output_path) / name
+        # os.makedirs(out_dir, exist_ok=True)
+
+        # # print(item[2])
+        # # print(len(item))
+        # # print(item[1])
+        # # print(data.shape)
+        # # data = data.squeeze()
+        # bag = item[0]
+        # bag = bag.squeeze()
+        # for i in range(bag.shape[0]):
+        #     img = bag[i, :, :, :]
+        #     img = img.squeeze()
+        #     img = img*255
+        #     img = img.numpy().astype(np.uint8).transpose(1,2,0)
+            
+        #     img = Image.fromarray(img)
+        #     img = img.convert('RGB')
+        #     img.save(f'{out_dir}/{i}.png')
+        c += 1
+        # else: break
+        # print(data.shape)
+        # print(label)
\ No newline at end of file
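
`get_data` loads a bag from HDF5 on first access, `_add_to_cache` appends it to a per-file list, and once more than `data_cache_size` files are cached, one other file is evicted and its `cache_idx` reset. A condensed sketch of that cache discipline, detached from the HDF5 plumbing:

```python
class TinyBagCache:
    """Condensed restatement of HDF5MILDataloader's caching behaviour."""

    def __init__(self, max_files: int = 20):
        self.max_files = max_files
        self.cache = {}                 # file_path -> list of loaded bags

    def add(self, path: str, bag) -> int:
        # _add_to_cache: append the bag and return its cache_idx.
        self.cache.setdefault(path, []).append(bag)
        if len(self.cache) > self.max_files:
            # Evict some cached file other than the one just touched.
            victim = next(k for k in self.cache if k != path)
            self.cache.pop(victim)
        return len(self.cache[path]) - 1
```
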
diff --git a/datasets/data_interface.py b/datasets/data_interface.py
index 3952e5bae7d77d2f67f6e417906e7accf9c00517..12a0f8c450b4945658d3566cc5c157116bccab66 100644
--- a/datasets/data_interface.py
+++ b/datasets/data_interface.py
@@ -1,9 +1,13 @@
 import inspect # inspect Python class parameters and module/function source code
 import importlib # used to dynamically import modules
+from typing import Optional
 import pytorch_lightning as pl
 from torch.utils.data import random_split, DataLoader
 from torchvision.datasets import MNIST
 from torchvision import transforms
+from .camel_dataloader import FeatureBagLoader
+from .custom_dataloader import HDF5MILDataloader
+from pathlib import Path
 
 class DataInterface(pl.LightningDataModule):
 
@@ -24,6 +28,8 @@ class DataInterface(pl.LightningDataModule):
         self.dataset_name = dataset_name
         self.kwargs = kwargs
         self.load_data_module()
+        home = Path.cwd().parts[1]
+        self.data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv'
 
  
 
@@ -46,14 +52,23 @@ class DataInterface(pl.LightningDataModule):
         """
         # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
-            self.train_dataset = self.instancialize(state='train')
-            self.val_dataset = self.instancialize(state='val')
+            dataset = FeatureBagLoader(data_root = self.data_root,
+                                                train=True)
+            a = int(len(dataset)* 0.8)
+            b = int(len(dataset) - a)
+            print(a)
+            print(b)
+            self.train_dataset, self.val_dataset = random_split(dataset, [a, b])
+            # self.train_dataset = self.instancialize(state='train')
+            # self.val_dataset = self.instancialize(state='val')
  
 
         # Assign test dataset for use in dataloader(s)
         if stage == 'test' or stage is None:
             # self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
-            self.test_dataset = self.instancialize(state='test')
+            self.test_dataset = FeatureBagLoader(data_root = self.data_root,
+                                                train=False)
+            # self.test_dataset = self.instancialize(state='test')
 
 
     def train_dataloader(self):
@@ -87,4 +102,62 @@ class DataInterface(pl.LightningDataModule):
             if arg in inkeys:
                 args1[arg] = self.kwargs[arg]
         args1.update(other_args)
-        return self.data_module(**args1)
\ No newline at end of file
+        return self.data_module(**args1)
+
+class MILDataModule(pl.LightningDataModule):
+
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, **kwargs):
+        super().__init__()
+        self.data_root = data_root
+        self.label_path = label_path
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.image_size = 384
+        self.n_classes = n_classes
+        self.target_number = 9
+        self.mean_bag_length = 10
+        self.var_bag_length = 2
+        self.num_bags_train = 200
+        self.num_bags_test = 50
+        self.seed = 1
+
+        self.cache = cache
+
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        # if self.n_classes == 2:
+        #     if stage in (None, 'fit'):
+        #         dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes)
+        #         a = int(len(dataset)* 0.8)
+        #         b = int(len(dataset) - a)
+        #         self.train_data, self.valid_data = random_split(dataset, [a, b])
+
+        #     if stage in (None, 'test'):
+        #         self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes)
+        # else:
+        home = Path.cwd().parts[1]
+        # self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json'
+        # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv'
+        # test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv'
+        
+
+        if stage in (None, 'fit'):
+            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
+            # print(len(dataset))
+            a = int(len(dataset)* 0.8)
+            b = int(len(dataset) - a)
+            self.train_data, self.valid_data = random_split(dataset, [a, b])
+
+        if stage in (None, 'test'):
+            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes)
+
+        return super().setup(stage=stage)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+    
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
+    
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
\ No newline at end of file
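
`MILDataModule` splits the label JSON's 'train' entries 80/20 into train/val via `random_split` and builds the 'test' entries into a separate dataset. A hedged usage sketch with placeholder paths (the model is whatever `ModelInterface` instance the training script constructs):

```python
import pytorch_lightning as pl
from datasets.data_interface import MILDataModule

dm = MILDataModule(
    data_root='/path/to/hdf5/256_256um_split/',   # placeholder path
    label_path='/path/to/split_PAS_bin.json',     # placeholder path
    batch_size=1,
    num_workers=8,
    n_classes=2,
)
# trainer = pl.Trainer(gpus=1, max_epochs=200)
# trainer.fit(model, datamodule=dm)               # model: ModelInterface
```
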
diff --git a/models/TransMIL.py b/models/TransMIL.py
index 3cb4e52c6ce1802bcddcde727aa80c260efe2e76..ce40a26b37b1886bf5698ee4ab8ecf07c1e4e2c8 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -47,7 +47,8 @@ class TransMIL(nn.Module):
     def __init__(self, n_classes):
         super(TransMIL, self).__init__()
         self.pos_layer = PPEG(dim=512)
-        self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
+        self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
+        # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
         self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
         self.n_classes = n_classes
         self.layer1 = TransLayer(dim=512)
@@ -56,11 +57,10 @@ class TransMIL(nn.Module):
         self._fc2 = nn.Linear(512, self.n_classes)
 
 
     def forward(self, **kwargs):
 
         h = kwargs['data'].float() #[B, n, 1024]
-        
-        h = self._fc1(h) #[B, n, 512]
+        # h = self._fc1(h) #[B, n, 512]
         
         #---->pad
         H = h.shape[1]
@@ -86,15 +86,19 @@ class TransMIL(nn.Module):
         h = self.norm(h)[:,0]
 
         #---->predict
-        logits = self._fc2(h) #[B, n_classes]
+        logits = self._fc2(torch.sigmoid(h)) #[B, n_classes]
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim = 1)
         results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
         return results_dict
 
 if __name__ == "__main__":
-    data = torch.randn((1, 6000, 1024)).cuda()
+    data = torch.randn((1, 6000, 512)).cuda()
     model = TransMIL(n_classes=2).cuda()
     print(model.eval())
     results_dict = model(data = data)
     print(results_dict)
+    logits = results_dict['logits']
+    Y_prob = results_dict['Y_prob']
+    Y_hat = results_dict['Y_hat']
+    # print(F.sigmoid(logits))
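
With the backbone now emitting 512-dim features, `_fc1` is bypassed and `h` enters the transformer stack directly. The `#---->pad` step shown only partially above repeats leading tokens until the sequence length is a perfect square, so PPEG can reshape it into a 2D grid for its convolutional position encoding. A standalone restatement, consistent with the original TransMIL implementation:

```python
import numpy as np
import torch

def pad_to_square(h: torch.Tensor) -> torch.Tensor:
    # h: [B, N, C]; repeat leading tokens so N becomes a perfect square.
    B, N, C = h.shape
    side = int(np.ceil(np.sqrt(N)))
    add_length = side * side - N
    return torch.cat([h, h[:, :add_length, :]], dim=1)

h = torch.randn(1, 6000, 512)
print(pad_to_square(h).shape)  # torch.Size([1, 6084, 512]); 78**2 == 6084
```
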
diff --git a/models/__init__.py b/models/__init__.py
index 497cee19810dc56d0933816137b554d7b3c760cc..73aad9d74d278565976a1dd2c63c69dc7a0997ad 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1 +1 @@
-from .model_interface import ModelInterface
\ No newline at end of file
+from .model_interface import ModelInterface
diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e1ddff6d6f3cbadfd7f1c0c4686a7806f5896e8
Binary files /dev/null and b/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45428bd7fc4c40dc9fb1116df4a96aea2fb568aa
Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e81c0ccdd33e6fe68a5ffe1b602990134944c80
Binary files /dev/null and b/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/models/__pycache__/vision_transformer.cpython-39.pyc b/models/__pycache__/vision_transformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e278bdeba138e438387aa38b7a83ca7bb7819a7e
Binary files /dev/null and b/models/__pycache__/vision_transformer.cpython-39.pyc differ
diff --git a/models/model_interface.py b/models/model_interface.py
index c7bb72323c99a57672eff72f7e660ba90c269594..1b0f6e19f8429b764e6fd45b96bf821981d0731e 100755
--- a/models/model_interface.py
+++ b/models/model_interface.py
@@ -4,6 +4,9 @@ import inspect
 import importlib
 import random
 import pandas as pd
+import seaborn as sns
+from pathlib import Path
+from matplotlib import pyplot as plt
 
 #---->
 from MyOptimizer import create_optimizer
@@ -18,9 +21,11 @@ import torchmetrics
 
 #---->
 import pytorch_lightning as pl
+from .vision_transformer import vit_small
+from torchvision import models
+from torchvision.models import resnet
 
-
-class  ModelInterface(pl.LightningModule):
+class ModelInterface(pl.LightningModule):
 
     #---->init
     def __init__(self, model, loss, optimizer, **kargs):
@@ -37,11 +42,11 @@ class  ModelInterface(pl.LightningModule):
         
         #---->Metrics
         if self.n_classes > 2: 
-            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'macro')
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
                                                                            average='micro'),
                                                      torchmetrics.CohenKappa(num_classes = self.n_classes),
-                                                     torchmetrics.F1(num_classes = self.n_classes,
+                                                     torchmetrics.F1Score(num_classes = self.n_classes,
                                                                      average = 'macro'),
                                                      torchmetrics.Recall(average = 'macro',
                                                                          num_classes = self.n_classes),
@@ -49,17 +54,19 @@ class  ModelInterface(pl.LightningModule):
                                                                             num_classes = self.n_classes),
                                                      torchmetrics.Specificity(average = 'macro',
                                                                             num_classes = self.n_classes)])
+                                                                            
         else : 
-            self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'macro')
+            self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
                                                                            average = 'micro'),
                                                      torchmetrics.CohenKappa(num_classes = 2),
-                                                     torchmetrics.F1(num_classes = 2,
+                                                     torchmetrics.F1Score(num_classes = 2,
                                                                      average = 'macro'),
                                                      torchmetrics.Recall(average = 'macro',
                                                                          num_classes = 2),
                                                      torchmetrics.Precision(average = 'macro',
                                                                             num_classes = 2)])
+        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)                                                                    
         self.valid_metrics = metrics.clone(prefix = 'val_')
         self.test_metrics = metrics.clone(prefix = 'test_')
 
@@ -67,18 +74,103 @@ class  ModelInterface(pl.LightningModule):
         self.shuffle = kargs['data'].data_shuffle
         self.count = 0
 
+        self.out_features = 512
+        if kargs['backbone'] == 'dino':
+            #---> dino feature extractor
+            arch = 'vit_small'
+            patch_size = 16
+            n_last_blocks = 4
+            # num_labels = 1000
+            avgpool_patchtokens = False
+            home = Path.cwd().parts[1]
+
+            weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth'
+            model = vit_small(patch_size, num_classes=0)
+            # model.eval()
+            # set_parameter_requires_grad(model, feature_extracting)
+            for param in model.parameters():
+                param.requires_grad = False
+            # print(model.embed_dim)
+            # embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens))
+            # model.eval()
+            # print(embed_dim)
+            linear = nn.Linear(model.embed_dim, self.out_features)
+            linear.weight.data.normal_(mean=0.0, std=0.01)
+            linear.bias.data.zero_()
+            
+            self.model_ft = nn.Sequential(
+                model,
+                linear, 
+            )
+        elif kargs['backbone'] == 'resnet18':
+            resnet18 = models.resnet18(pretrained=True)
+            modules = list(resnet18.children())[:-1]
+            # model_ft.fc = nn.Linear(512, out_features)
+
+            res18 = nn.Sequential(
+                *modules,
+                
+            )
+            for param in res18.parameters():
+                param.requires_grad = False
+            self.model_ft = nn.Sequential(
+                res18,
+                nn.AdaptiveAvgPool2d(1),
+                View((-1, 512)),
+                nn.Linear(512, self.out_features),
+                nn.ReLU(),
+            )
+        elif kargs['backbone'] == 'resnet50':
+
+            resnet50 = models.resnet50(pretrained=True)    
+            # model_ft.fc = nn.Linear(1024, out_features)
+            modules = list(resnet50.children())[:-3]
+            res50 = nn.Sequential(
+                *modules,     
+            )
+            for param in res50.parameters():
+                param.requires_grad = False
+            self.model_ft = nn.Sequential(
+                res50,
+                nn.AdaptiveAvgPool2d(1),
+                View((-1, 1024)),
+                nn.Linear(1024, self.out_features),
+                nn.ReLU()
+            )
+        elif kargs['backbone'] == 'simple': #mil-ab attention
+            feature_extracting = False
+            self.model_ft = nn.Sequential(
+                nn.Conv2d(3, 20, kernel_size=5),
+                nn.ReLU(),
+                nn.MaxPool2d(2, stride=2),
+                nn.Conv2d(20, 50, kernel_size=5),
+                nn.ReLU(),
+                nn.MaxPool2d(2, stride=2),
+                View((-1, 1024)),
+                nn.Linear(1024, self.out_features),
+                nn.ReLU(),
+            )
 
     #---->remove v_num
-    def get_progress_bar_dict(self):
-        # don't show the version number
-        items = super().get_progress_bar_dict()
-        items.pop("v_num", None)
-        return items
+    # def get_progress_bar_dict(self):
+    #     # don't show the version number
+    #     items = super().get_progress_bar_dict()
+    #     items.pop("v_num", None)
+    #     return items
 
     def training_step(self, batch, batch_idx):
         #---->inference
-        data, label = batch
-        results_dict = self.model(data=data, label=label)
+        data, label, _ = batch
+        label = label.float()
+        data = data.squeeze(0)
+        # print(data.shape)
+        features = self.model_ft(data)
+        
+        features = features.unsqueeze(0)
+        # print(features.shape)
+        # features = features.squeeze()
+        results_dict = self.model(data=features) 
+        # results_dict = self.model(data=data, label=label)
         logits = results_dict['logits']
         Y_prob = results_dict['Y_prob']
         Y_hat = results_dict['Y_hat']
@@ -87,8 +179,13 @@ class  ModelInterface(pl.LightningModule):
         loss = self.loss(logits, label)
 
         #---->acc log
         Y_hat = int(Y_hat)
-        Y = int(label)
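+        # Labels appear to arrive one-hot encoded, so argmax recovers the class index.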
+        Y = torch.argmax(label)
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat == Y)
 
@@ -106,19 +203,28 @@ class  ModelInterface(pl.LightningModule):
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
     def validation_step(self, batch, batch_idx):
-        data, label = batch
-        results_dict = self.model(data=data, label=label)
+
+        data, label, _ = batch
+
+        label = label.float()
+        data = data.squeeze(0)
+        features = self.model_ft(data)
+        features = features.unsqueeze(0)
+
+        results_dict = self.model(data=features)
         logits = results_dict['logits']
         Y_prob = results_dict['Y_prob']
         Y_hat = results_dict['Y_hat']
 
 
         #---->acc log
-        Y = int(label)
+        Y = torch.argmax(label)
+
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
 
 
     def validation_epoch_end(self, val_step_outputs):
@@ -126,13 +232,26 @@ class  ModelInterface(pl.LightningModule):
         probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
         target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
-        
         #---->
         self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
         self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
-        self.log_dict(self.valid_metrics(max_probs.squeeze() , target.squeeze()),
+        self.log_dict(self.valid_metrics(max_probs.squeeze() , target),
                           on_epoch = True, logger = True)
 
+        #----> log confusion matrix
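+        # (logger[0] is assumed to be the TensorBoard logger created in utils.load_loggers.)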
+        confmat = self.confusion_matrix(max_probs.squeeze(), target)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        plt.figure()
+        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+        plt.close(fig_)
+        self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
+
         #---->acc log
         for c in range(self.n_classes):
             count = self.data[c]["count"]
@@ -156,18 +275,24 @@ class  ModelInterface(pl.LightningModule):
         return [optimizer]
 
     def test_step(self, batch, batch_idx):
-        data, label = batch
-        results_dict = self.model(data=data, label=label)
+
+        data, label, _ = batch
+        label = label.float()
+        data = data.squeeze(0)
+        features = self.model_ft(data)
+        features = features.unsqueeze(0)
+
+        results_dict = self.model(data=features, label=label)
         logits = results_dict['logits']
         Y_prob = results_dict['Y_prob']
         Y_hat = results_dict['Y_hat']
 
         #---->acc log
-        Y = int(label)
+        Y = torch.argmax(label)
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
 
     def test_epoch_end(self, output_results):
         probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0)
@@ -176,12 +301,20 @@ class  ModelInterface(pl.LightningModule):
         
         #---->
         auc = self.AUROC(probs, target.squeeze())
-        metrics = self.test_metrics(max_probs.squeeze() , target.squeeze())
-        metrics['auc'] = auc
+        metrics = self.test_metrics(max_probs.squeeze() , target)
+        metrics['test_auc'] = auc
+
         for keys, values in metrics.items():
             print(f'{keys} = {values}')
             metrics[keys] = values.cpu().numpy()
-        print()
         #---->acc log
         for c in range(self.n_classes):
             count = self.data[c]["count"]
@@ -192,6 +325,16 @@ class  ModelInterface(pl.LightningModule):
                 acc = float(correct) / count
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+        confmat = self.confusion_matrix(max_probs.squeeze(), target)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        plt.figure()
+        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+        plt.savefig(f'{self.log_path}/cm_test')
+        plt.close(fig_)
+
         #---->
         result = pd.DataFrame([metrics])
         result.to_csv(self.log_path / 'result.csv')
@@ -226,4 +369,18 @@ class  ModelInterface(pl.LightningModule):
             if arg in inkeys:
                 args1[arg] = getattr(self.hparams.model, arg)
         args1.update(other_args)
-        return Model(**args1)
\ No newline at end of file
+        return Model(**args1)
+
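+# Helper so view() can be used inside nn.Sequential, e.g. View((-1, 512))
+# flattens (N, 512, 1, 1) pooled feature maps to (N, 512).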
+class View(nn.Module):
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, input):
+        '''Reshape the input to the shape stored on the module.'''
+        return input.view(*self.shape)
\ No newline at end of file
diff --git a/models/vision_transformer.py b/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebffe7a868b547806ff19e30290729f9688cc0fa
--- /dev/null
+++ b/models/vision_transformer.py
@@ -0,0 +1,330 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#     http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Mostly copy-paste from timm library.
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+import math
+import warnings
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+# from utils import trunc_normal_
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                      "The distribution of values may be incorrect.",
+                      stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
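+# e.g. trunc_normal_(torch.empty(3, 5), std=.02) fills the tensor in-place from
+# N(0, .02^2), clamped to the absolute bounds [a, b] (default [-2, 2]).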
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
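+    # Unlike timm's Attention, forward also returns the attention map, which
+    # Block.forward(return_attention=True) and get_last_selfattention rely on.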
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn
+
+
+class Block(nn.Module):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, return_attention=False):
+        y, attn = self.attn(self.norm1(x))
+        if return_attention:
+            return attn
+        x = x + self.drop_path(y)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        num_patches = (img_size // patch_size) * (img_size // patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer """
+    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        # Classifier head
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        class_pos_embed = self.pos_embed[:, 0]
+        patch_pos_embed = self.pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size
+        h0 = h // self.patch_embed.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x):
+        B, nc, w, h = x.shape
+        x = self.patch_embed(x)  # patch linear embedding
+
+        # add the [CLS] token to the embed patch tokens
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return self.pos_drop(x)
+
+    def forward(self, x):
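+    # forward returns only the normalized [CLS] token (shape (B, embed_dim)); the
+    # patch tokens are discarded, matching how DINO features are consumed upstream.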
+        x = self.prepare_tokens(x)
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+        return x[:, 0]
+
+    def get_last_selfattention(self, x):
+        x = self.prepare_tokens(x)
+        for i, blk in enumerate(self.blocks):
+            if i < len(self.blocks) - 1:
+                x = blk(x)
+            else:
+                # return attention of the last block
+                return blk(x, return_attention=True)
+
+    def get_intermediate_layers(self, x, n=1):
+        x = self.prepare_tokens(x)
+        # we return the output tokens from the `n` last blocks
+        output = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if len(self.blocks) - i <= n:
+                output.append(self.norm(x))
+        return output
+
+
+def vit_tiny(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_small(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4,  # DINO's reference vit_small uses num_heads=6
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_base(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
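+# Minimal usage sketch (assuming 224x224 inputs and the default patch size):
+#   model = vit_small(patch_size=16)
+#   feats = model(torch.randn(1, 3, 224, 224))  # -> (1, 384) [CLS] embedding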
+
+
+class DINOHead(nn.Module):
+    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        if nlayers == 1:
+            self.mlp = nn.Linear(in_dim, bottleneck_dim)
+        else:
+            layers = [nn.Linear(in_dim, hidden_dim)]
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+            for _ in range(nlayers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if use_bn:
+                    layers.append(nn.BatchNorm1d(hidden_dim))
+                layers.append(nn.GELU())
+            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            self.mlp = nn.Sequential(*layers)
+        self.apply(self._init_weights)
+        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+        if norm_last_layer:
+            self.last_layer.weight_g.requires_grad = False
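+            # Freezing weight_g pins the weight-norm magnitude at 1, so only the
+            # direction of the final projection is learned (DINO's norm_last_layer).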
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        x = nn.functional.normalize(x, dim=-1, p=2)
+        x = self.last_layer(x)
+        return x
diff --git a/train.py b/train.py
index c4b4d7d7c27fdeacbfcd9b56609557f5beed8372..182e3c0f4f0df1096166441731e39794f8599766 100644
--- a/train.py
+++ b/train.py
@@ -3,8 +3,9 @@ from pathlib import Path
 import numpy as np
 import glob
 
-from datasets import DataInterface
-from models import ModelInterface
+from datasets.data_interface import DataInterface, MILDataModule
+from models.model_interface import ModelInterface
+import models.vision_transformer as vits
 from utils.utils import *
 
 # pytorch_lightning
@@ -15,8 +16,8 @@ from pytorch_lightning import Trainer
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
-    parser.add_argument('--config', default='Camelyon/TransMIL.yaml',type=str)
-    parser.add_argument('--gpus', default = [2])
+    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
     parser.add_argument('--fold', default = 0)
     args = parser.parse_args()
     return args
@@ -34,20 +35,31 @@ def main(cfg):
     cfg.callbacks = load_callbacks(cfg)
 
     #---->Define Data 
-    DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
-                'train_num_workers': cfg.Data.train_dataloader.num_workers,
-                'test_batch_size': cfg.Data.test_dataloader.batch_size,
-                'test_num_workers': cfg.Data.test_dataloader.num_workers,
-                'dataset_name': cfg.Data.dataset_name,
-                'dataset_cfg': cfg.Data,}
-    dm = DataInterface(**DataInterface_dict)
+    home = Path.cwd().parts[1]
+    DataInterface_dict = {
+                'data_root': cfg.Data.data_dir,
+                'label_path': cfg.Data.label_file,
+                'batch_size': cfg.Data.train_dataloader.batch_size,
+                'num_workers': cfg.Data.train_dataloader.num_workers,
+                'n_classes': cfg.Model.n_classes,
+                }
+    dm = MILDataModule(**DataInterface_dict)
 
     #---->Define Model
     ModelInterface_dict = {'model': cfg.Model,
                             'loss': cfg.Loss,
                             'optimizer': cfg.Optimizer,
                             'data': cfg.Data,
-                            'log': cfg.log_path
+                            'log': cfg.log_path,
+                            'backbone': cfg.Model.backbone,
                             }
     model = ModelInterface(**ModelInterface_dict)
     
@@ -57,12 +69,18 @@ def main(cfg):
         logger=cfg.load_loggers,
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
+        min_epochs = 200,
         gpus=cfg.General.gpus,
-        amp_level=cfg.General.amp_level,  
+        amp_backend='native',
         precision=cfg.General.precision,  
         accumulate_grad_batches=cfg.General.grad_acc,
-        deterministic=True,
-        check_val_every_n_epoch=1,
+        check_val_every_n_epoch=10,
     )
 
     #---->train or test
@@ -83,7 +101,7 @@ if __name__ == '__main__':
 
     #---->update
     cfg.config = args.config
-    cfg.General.gpus = args.gpus
     cfg.General.server = args.stage
     cfg.Data.fold = args.fold
 
diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dcd009c36ba1c22656489be43171004bc84df85
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1eeb958497b23b1f9e9934764083531720a66d0
Binary files /dev/null and b/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/utils/utils.py b/utils/utils.py
index 1b7e44f8b1fd69860ebaa3483688aac78a52a7ba..96ed223d1f73fb9afbf951486376bc02c76ae75a 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -14,7 +14,7 @@ def load_loggers(cfg):
 
     log_path = cfg.General.log_path
     Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = Path(cfg.config).parent 
+    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}'
     version_name = Path(cfg.config).name[:-5]
     cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
     print(f'---->Log dir: {cfg.log_path}')
@@ -31,8 +31,10 @@ def load_loggers(cfg):
 
 
 #---->load Callback
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+
 def load_callbacks(cfg):
 
     Mycallbacks = []
@@ -47,7 +49,21 @@ def load_callbacks(cfg):
         verbose=True,
         mode='min'
     )
+
     Mycallbacks.append(early_stop_callback)
+    progress_bar = RichProgressBar(
+        theme=RichProgressBarTheme(
+            description='green_yellow',
+            progress_bar='green1',
+            progress_bar_finished='green1',
+            batch_progress='green_yellow',
+            time='grey82',
+            processing_speed='grey82',
+            metrics='grey82'
+        )
+    )
+    Mycallbacks.append(progress_bar)
 
     if cfg.General.server == 'train' :
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
@@ -64,7 +80,7 @@ def load_callbacks(cfg):
 import torch
 import torch.nn.functional as F
 def cross_entropy_torch(x, y):
-    x_softmax = [F.softmax(x[i]) for i in range(len(x))]
-    x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(len(y))])
-    loss = - torch.sum(x_log) / len(y)
+    x_softmax = [F.softmax(x[i], dim=0) for i in range(len(x))]
+    x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
+    loss = - torch.sum(x_log) / y.shape[0]
     return loss
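+
+# Sanity check (illustrative): for logits x of shape (N, C) and integer targets y,
+# cross_entropy_torch(x, y) should closely match F.cross_entropy(x, y).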