From d7fe105eceef23555df224fd257187c17ba64430 Mon Sep 17 00:00:00 2001
From: Ycblue <yuchialan@gmail.com>
Date: Thu, 23 Jun 2022 10:39:02 +0200
Subject: [PATCH] Add jpg_dataloader, remove fixed_sized_bag

---
 .gitignore                                    |   4 +-
 DeepGraft/AttMIL_resnet18_debug.yaml          |  51 ++
 DeepGraft/AttMIL_simple_no_other.yaml         |  49 ++
 DeepGraft/AttMIL_simple_no_viral.yaml         |  48 ++
 DeepGraft/AttMIL_simple_tcmr_viral.yaml       |  49 ++
 DeepGraft/TransMIL_efficientnet_no_other.yaml |   4 +-
 DeepGraft/TransMIL_efficientnet_no_viral.yaml |   6 +-
 .../TransMIL_efficientnet_tcmr_viral.yaml     |   8 +-
 ...ebug.yaml => TransMIL_resnet18_debug.yaml} |   7 +-
 DeepGraft/TransMIL_resnet50_tcmr_viral.yaml   |   7 +-
 __pycache__/train_loop.cpython-39.pyc         | Bin 0 -> 9235 bytes
 .../custom_dataloader.cpython-39.pyc          | Bin 15301 -> 15363 bytes
 .../custom_jpg_dataloader.cpython-39.pyc      | Bin 0 -> 11037 bytes
 .../__pycache__/data_interface.cpython-39.pyc | Bin 7011 -> 10339 bytes
 datasets/custom_dataloader.py                 |  77 +--
 datasets/custom_jpg_dataloader.py             | 459 ++++++++++++++++++
 datasets/data_interface.py                    | 107 +++-
 models/AttMIL.py                              |  79 +++
 models/TransMIL.py                            |  40 +-
 models/__pycache__/AttMIL.cpython-39.pyc      | Bin 0 -> 1551 bytes
 models/__pycache__/TransMIL.cpython-39.pyc    | Bin 3328 -> 3233 bytes
 .../model_interface.cpython-39.pyc            | Bin 11282 -> 14054 bytes
 models/model_interface.py                     | 249 ++++++++--
 test_visualize.py                             | 148 ++++++
 train.py                                      |  58 ++-
 train_loop.py                                 | 212 ++++++++
 utils/__pycache__/utils.cpython-39.pyc        | Bin 3450 -> 4005 bytes
 utils/utils.py                                | 126 +----
 28 files changed, 1547 insertions(+), 241 deletions(-)
 create mode 100644 DeepGraft/AttMIL_resnet18_debug.yaml
 create mode 100644 DeepGraft/AttMIL_simple_no_other.yaml
 create mode 100644 DeepGraft/AttMIL_simple_no_viral.yaml
 create mode 100644 DeepGraft/AttMIL_simple_tcmr_viral.yaml
 rename DeepGraft/{TransMIL_debug.yaml => TransMIL_resnet18_debug.yaml} (91%)
 create mode 100644 __pycache__/train_loop.cpython-39.pyc
 create mode 100644 datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
 create mode 100644 datasets/custom_jpg_dataloader.py
 create mode 100644 models/AttMIL.py
 create mode 100644 models/__pycache__/AttMIL.cpython-39.pyc
 create mode 100644 test_visualize.py
 create mode 100644 train_loop.py

diff --git a/.gitignore b/.gitignore
index 4cf8dd1..c9e4fda 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
-logs/*
\ No newline at end of file
+logs/*
+lightning_logs/*
+test/*
\ No newline at end of file
diff --git a/DeepGraft/AttMIL_resnet18_debug.yaml b/DeepGraft/AttMIL_resnet18_debug.yaml
new file mode 100644
index 0000000..03ebd7e
--- /dev/null
+++ b/DeepGraft/AttMIL_resnet18_debug.yaml
@@ -0,0 +1,51 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [1]
+    epochs: &epoch 1 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 2
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
+    fold: 0
+    nfold: 2
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
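
A quick note on reading these configs: `epochs: &epoch 1` defines a YAML anchor, and since no alias dereferences it inside this file, a loader simply sees the scalar value. Below is a minimal sketch of loading one of these files, assuming PyYAML (the repository's own config loader is not part of this patch):

import yaml  # assumption: PyYAML is used to parse the configs

with open('DeepGraft/AttMIL_resnet18_debug.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['General']['epochs'])                       # -> 1, the &epoch anchor resolves to its value
print(cfg['Model']['name'], cfg['Model']['n_classes'])  # -> AttMIL 2
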
diff --git a/DeepGraft/AttMIL_simple_no_other.yaml b/DeepGraft/AttMIL_simple_no_other.yaml
new file mode 100644
index 0000000..ae90a80
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_no_other.yaml
@@ -0,0 +1,49 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 200 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 20
+    server: train #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 5
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_no_viral.yaml b/DeepGraft/AttMIL_simple_no_viral.yaml
new file mode 100644
index 0000000..37ee074
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 4
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_tcmr_viral.yaml b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
new file mode 100644
index 0000000..c982d3a
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
@@ -0,0 +1,49 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 300 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 20
+    server: train #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: AttMIL
+    n_classes: 2
+    backbone: simple
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml
index 79d8ea8..7687a0c 100644
--- a/DeepGraft/TransMIL_efficientnet_no_other.yaml
+++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml
@@ -6,10 +6,10 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [0]
-    epochs: &epoch 1000 
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 20
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
index 8780060..98fe377 100644
--- a/DeepGraft/TransMIL_efficientnet_no_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
@@ -5,11 +5,11 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [3]
-    epochs: &epoch 500 
+    gpus: [0, 2]
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 20
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
index f69b5bf..5223032 100644
--- a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
@@ -5,7 +5,7 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [3]
+    gpus: [0]
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
@@ -16,10 +16,11 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    data_dir: '/home/ylan/data/DeepGraft/256_256um/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 1
-    nfold: 4
+    nfold: 3
+    cross_val: True
 
     train_dataloader:
         batch_size: 1 
@@ -33,6 +34,7 @@ Model:
     name: TransMIL
     n_classes: 2
     backbone: efficientnet
+    in_features: 512
 
 
 Optimizer:
diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_resnet18_debug.yaml
similarity index 91%
rename from DeepGraft/TransMIL_debug.yaml
rename to DeepGraft/TransMIL_resnet18_debug.yaml
index d83ce0d..29bfa1e 100644
--- a/DeepGraft/TransMIL_debug.yaml
+++ b/DeepGraft/TransMIL_resnet18_debug.yaml
@@ -6,10 +6,10 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [1]
-    epochs: &epoch 200 
+    epochs: &epoch 1 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 2
     server: test #train #test
     log_path: logs/
 
@@ -19,7 +19,8 @@ Data:
     data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
     fold: 0
-    nfold: 4
+    nfold: 2
+    cross_val: True
 
     train_dataloader:
         batch_size: 1 
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
index a756616..f6e4697 100644
--- a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -5,8 +5,8 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [3]
-    epochs: &epoch 500 
+    gpus: [0]
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
     patience: 50
@@ -19,7 +19,8 @@ Data:
     data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
     fold: 1
-    nfold: 4
+    nfold: 3
+    cross_val: True
 
     train_dataloader:
         batch_size: 1 
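
Several of the configs above now enable `cross_val` and change `nfold` from 4 to 3. As a hedged sketch only (the actual split logic lives in datasets/data_interface.py, whose changes are summarized but not fully shown in this patch), an `nfold`/`fold` pair could map onto a K-fold split like this:

from sklearn.model_selection import KFold  # assumption: sklearn-style splitting

def split_slides(slide_names, nfold=3, fold=1, seed=2021):
    """Return (train, val) slide lists; `fold` selects the held-out fold."""
    kf = KFold(n_splits=nfold, shuffle=True, random_state=seed)
    train_idx, val_idx = list(kf.split(slide_names))[fold]
    return ([slide_names[i] for i in train_idx],
            [slide_names[i] for i in val_idx])
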
diff --git a/__pycache__/train_loop.cpython-39.pyc b/__pycache__/train_loop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65c5ca839312787cce6aed19f346ae40a7c04d83
GIT binary patch
[base85 payload for the new compiled __pycache__ file omitted; not human-readable]

diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc
index 4a200bbb7d34328019d9bb9e604b0524127ae863..99e793aa4b84b2a5e9ee1882437346e6e4a33002 100644
GIT binary patch
[base85 delta payloads for the updated compiled __pycache__ file omitted; not human-readable]

diff --git a/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20aefffd1a14afe3c499f646e9eb674810896281
GIT binary patch
[base85 payload for the new compiled __pycache__ file omitted; not human-readable]

diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc
index 9550db1509e8d477a1c15534a76c6f87976fbec4..59584d31b16ee5ba7d08220a2174fb5e79e9aaed 100644
GIT binary patch
[base85 literal and delta payloads for the updated compiled __pycache__ file omitted; not human-readable]

diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py
index 02850f5..cf65534 100644
--- a/datasets/custom_dataloader.py
+++ b/datasets/custom_dataloader.py
@@ -41,7 +41,7 @@ class HDF5MILDataloader(data.Dataset):
         data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
 
     """
-    def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024):
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024):
         super().__init__()
 
         self.data_info = []
@@ -55,7 +55,6 @@ class HDF5MILDataloader(data.Dataset):
         self.label_path = label_path
         self.n_classes = n_classes
         self.bag_size = bag_size
-        self.backbone = backbone
         # self.label_file = label_path
         recursive = True
         
@@ -134,10 +133,6 @@ class HDF5MILDataloader(data.Dataset):
             RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
             transforms.ToTensor()
         ])
-        if self.backbone == 'dino':
-            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
-
-        # self._add_data_infos(load_data)
 
     def __getitem__(self, index):
         # get data
@@ -150,9 +145,6 @@ class HDF5MILDataloader(data.Dataset):
             # print(img.shape)
             for img in batch: # expects numpy 
                 img = img.numpy().astype(np.uint8)
-                if self.backbone == 'dino':
-                    img = self.feature_extractor(images=img, return_tensors='pt')
-                    # img = self.resize_transforms(img)
                 img = seq_img_d.augment_image(img)
                 img = self.val_transforms(img)
                 out_batch.append(img)
@@ -160,23 +152,24 @@ class HDF5MILDataloader(data.Dataset):
         else:
             for img in batch:
                 img = img.numpy().astype(np.uint8)
-                if self.backbone == 'dino':
-                    img = self.feature_extractor(images=img, return_tensors='pt')
-                    img = self.resize_transforms(img)
-                
                 img = self.val_transforms(img)
                 out_batch.append(img)
 
-        if len(out_batch) == 0:
-            # print(name)
-            out_batch = torch.randn(self.bag_size,3,256,256)
-        else: out_batch = torch.stack(out_batch)
+        # if len(out_batch) == 0:
+        #     # print(name)
+        #     out_batch = torch.randn(self.bag_size,3,256,256)
+        # else: 
+        out_batch = torch.stack(out_batch)
         # print(out_batch.shape)
         # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
-
+        # print(out_batch.shape)
+        # guard against bags stacked in an unexpected dimension order: swap
+        # dims 1 and 2, e.g. (bag, 256, 3, 256) -> (bag, 3, 256, 256)
+        if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3, 256, 256]):
+            print(name)
+            print(out_batch.shape)
+            out_batch = torch.permute(out_batch, (0, 2, 1, 3))
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
-        return out_batch, label, name
+        return out_batch, label, name #, name_batch
 
     def __len__(self):
         return len(self.data_info)
@@ -184,7 +177,9 @@ class HDF5MILDataloader(data.Dataset):
     def _add_data_infos(self, file_path, load_data):
         wsi_name = Path(file_path).stem
         if wsi_name in self.slideLabelDict:
+            # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
             label = self.slideLabelDict[wsi_name]
+            # print(wsi_name)
             idx = -1
             self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
 
@@ -195,14 +190,29 @@ class HDF5MILDataloader(data.Dataset):
         """
         with h5py.File(file_path, 'r') as h5_file:
             wsi_batch = []
+            tile_names = []
             for tile in h5_file.keys():
+                
                 img = h5_file[tile][:]
                 img = img.astype(np.uint8)
                 img = torch.from_numpy(img)
                 # img = self.resize_transforms(img)
+                
                 wsi_batch.append(img)
-            wsi_batch = torch.stack(wsi_batch)
-            wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size)
+                tile_names.append(tile)
+
+            # known empty container: /home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
+            if wsi_batch:
+                wsi_batch = torch.stack(wsi_batch)
+            else: 
+                print('Empty Container: ', file_path)
+                wsi_batch = torch.randn(self.bag_size,3,256,256)
+
+            if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]):
+                print(file_path)
+                print(wsi_batch.shape)
+            wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size) # second return value is the kept tile count, not tile names
             idx = self._add_to_cache(wsi_batch, file_path)
             file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
             self.data_info[file_idx + idx]['cache_idx'] = idx
@@ -461,16 +471,19 @@ class RandomHueSaturationValue(object):
             img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
         return img #, lbl
 
-def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512):
+def to_fixed_size_bag(bag, bag_size: int = 512):
 
     # get up to bag_size elements
     bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
     bag_samples = bag[bag_idxs]
+    # bag_sample_names = [bag_names[i] for i in bag_idxs]
     
     # zero-pad if we don't have enough samples
     zero_padded = torch.cat((bag_samples,
                             torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+    # zero_padded_names = bag_sample_names + ['']*(bag_size - len(bag_sample_names))
     return zero_padded, min(bag_size, len(bag))
+    # return zero_padded, zero_padded_names, min(bag_size, len(bag))
 
 
 class RandomHueSaturationValue(object):
@@ -510,11 +523,11 @@ if __name__ == '__main__':
     train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
     data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
     output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
     os.makedirs(output_path, exist_ok=True)
 
-    n_classes = 2
+    n_classes = 5
 
     dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
     # print(dataset.dataset)
@@ -528,15 +541,19 @@ if __name__ == '__main__':
 
     # print(len(dataset))
     # # x = 0
+    #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
     c = 0
     label_count = [0] *n_classes
     for item in dl: 
-        # if c >=10:
-        #     break
+        if c >=10:
+            break
         bag, label, name = item
-        label_count[np.argmax(label)] += 1
-    print(label_count)
-    print(len(train_ds))
+        print(name)
+        # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
+        
+            # print(bag)
+            # print(label)
+        c += 1
     #     # # print(bag.shape)
     #     # if bag.shape[1] == 1:
     #     #     print(name)
@@ -578,7 +595,7 @@ if __name__ == '__main__':
     #         o_img = Image.fromarray(o_img)
     #         o_img = o_img.convert('RGB')
     #         o_img.save(f'{output_path}/{i}_original.png')
-        # c += 1
+        
     #     break
         # else: break
         # print(data.shape)
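
The padding contract of `to_fixed_size_bag` is unchanged by this patch and worth spelling out: bags shorter than `bag_size` are zero-padded, and the second return value is the count of real tiles, not their names (the name-returning variant remains commented out). A small usage check, assuming the function is imported from datasets.custom_dataloader:

import torch
from datasets.custom_dataloader import to_fixed_size_bag

bag = torch.randn(7, 3, 256, 256)           # 7 tiles
padded, n_real = to_fixed_size_bag(bag, bag_size=10)
assert padded.shape == (10, 3, 256, 256)    # 3 zero tiles appended
assert n_real == 7                          # min(bag_size, len(bag))
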
diff --git a/datasets/custom_jpg_dataloader.py b/datasets/custom_jpg_dataloader.py
new file mode 100644
index 0000000..722b275
--- /dev/null
+++ b/datasets/custom_jpg_dataloader.py
@@ -0,0 +1,459 @@
+'''
+ToDo: remove bag_size
+'''
+
+
+import numpy as np
+from pathlib import Path
+import torch
+from torch.utils import data
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+import torchvision.transforms as transforms
+from PIL import Image
+import cv2
+import json
+import albumentations as A
+from imgaug import augmenters as iaa
+import imgaug as ia
+from torchsampler import ImbalancedDatasetSampler
+
+
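+# RangeNormalization below maps uint8 pixel values into [-1, 1]:
+#   (0/255 - 0.5)/0.5 = -1.0 and (255/255 - 0.5)/0.5 = 1.0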
+class RangeNormalization(object):
+    def __call__(self, sample):
+        img = sample
+        return (img / 255.0 - 0.5) / 0.5
+
+class JPGMILDataloader(data.Dataset):
+    """Represents an abstract HDF5 dataset. For single H5 container! 
+    
+    Input params:
+        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
+        mode: 'train' or 'test'
+        load_data: If True, loads all the data immediately into RAM. Use this if
+            the dataset is fits into memory. Otherwise, leave this at false and 
+            the data will load lazily.
+        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+
+    """
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        self.n_classes = n_classes
+        self.bag_size = bag_size
+        self.empty_slides = []
+        # self.label_file = label_path
+        recursive = True
+        
+        # read labels and slide_path from csv
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem 
+
+                # x_complete_path = Path(self.file_path)/Path(x)
+                for cohort in Path(self.file_path).iterdir():
+                    x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
+                    if x_complete_path.is_dir():
+                        if len(list(x_complete_path.iterdir())) > 50:
+                            # print(x_complete_path)
+                            self.slideLabelDict[x] = y
+                            self.files.append(x_complete_path)
+                        else:
+                            self.empty_slides.append(x_complete_path)
+        # print(len(self.empty_slides))
+        # print(self.empty_slides)
+
+
+        for slide_dir in tqdm(self.files):
+            self._add_data_infos(str(slide_dir.resolve()), load_data)
+
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
+        ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+
+        # self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+        #     # A.RandomGamma(),
+        #     # A.HorizontalFlip(),
+        #     # A.VerticalFlip(),
+        #     # A.RandomRotate90(),
+        #     # A.OneOf([
+        #     #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #     #     A.Affine(
+        #     #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+        #     #         rotate=(-45, 45),
+        #     #         shear=(-4, 4),
+        #     #         cval=8,
+        #     #         )
+        #     # ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+
+        ])
+        self.img_transforms = transforms.Compose([    
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandomVerticalFlip(p=1),
+            # histoTransforms.AutoRandomRotation(),
+            transforms.Lambda(lambda a: np.array(a)),
+        ]) 
+        self.hsv_transforms = transforms.Compose([
+            RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
+            transforms.ToTensor()
+        ])
+
+    def __getitem__(self, index):
+        # get data
+        batch, label, name = self.get_data(index)
+        out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
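+        # to_deterministic() freezes the sampled random parameters so every tile
+        # in this bag receives the exact same augmentation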
+        
+        if self.mode == 'train':
+            # print(img.shape)
+            for img in batch: # expects numpy 
+                img = img.numpy().astype(np.uint8)
+                # print(img.shape)
+                img = seq_img_d.augment_image(img)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+
+        else:
+            for img in batch:
+                img = img.numpy().astype(np.uint8)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+
+        # if len(out_batch) == 0:
+        #     # print(name)
+        #     out_batch = torch.randn(self.bag_size,3,256,256)
+        # else: 
+        out_batch = torch.stack(out_batch)
+        # print(out_batch.shape)
+        # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+        # print(out_batch.shape)
+        # if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]):
+        #     print(name)
+        #     print(out_batch.shape)
+        # out_batch = torch.permute(out_batch, (0, 2,1,3))
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
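+        # labels leave the dataset one-hot encoded; the model interface argmaxes
+        # them back to class indices for the metrics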
+        # print(out_batch)
+        return out_batch, label, name #, name_batch
+
+    def __len__(self):
+        return len(self.data_info)
+    
+    def _add_data_infos(self, file_path, load_data):
+        wsi_name = Path(file_path).stem
+        if wsi_name in self.slideLabelDict:
+            # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
+            label = self.slideLabelDict[wsi_name]
+            # print(wsi_name)
+            idx = -1
+            self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        wsi_batch = []
+        tile_names = []
+        # print(wsi_batch)
+        # for tile_path in Path(file_path).iterdir():
+        #     print(tile_path)
+        for tile_path in Path(file_path).iterdir():
+            # print(tile_path)
+            img = np.asarray(Image.open(tile_path)).astype(np.uint8)
+            img = torch.from_numpy(img)
+
+            # print(wsi_batch)
+            wsi_batch.append(img)
+            
+            tile_names.append(tile_path.stem)
+                
+        # if wsi_batch:
+        wsi_batch = torch.stack(wsi_batch)
+        if len(wsi_batch.shape) < 4:
+            # Tensor.unsqueeze is not in-place; reassign so the batch dim is actually added
+            wsi_batch = wsi_batch.unsqueeze(0)
+        # else: 
+        #     print('Empty Container: ', file_path)
+        #     self.empty_slides.append(file_path)
+        #     wsi_batch = torch.randn(self.bag_size,256,256,3)
+        # print(wsi_batch.shape)
+        # if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]):
+        #     print(file_path)
+        #     print(wsi_batch.shape)
+        # wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size)
+        idx = self._add_to_cache(wsi_batch, file_path)
+        file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+        self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # remove invalid cache_idx
+            # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+        """
+        if data_path not in self.data_cache:
+            self.data_cache[data_path] = [data]
+        else:
+            self.data_cache[data_path].append(data)
+        return len(self.data_cache[data_path]) - 1
+
+    # def get_data_infos(self, type):
+    #     """Get data infos belonging to a certain type of data.
+    #     """
+    #     data_info_type = [di for di in self.data_info if di['type'] == type]
+    #     return data_info_type
+
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_labels(self, indices):
+
+        return [self.data_info[i]['label'] for i in indices]
+        # return self.slideLabelDict.values()
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+            dataset. This will make sure that the data is loaded in case it is
+            not part of the data cache.
+            i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+        
+        # get new cache_idx assigned by _load_data_info
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name
+
+
+
+class RandomHueSaturationValue(object):
+
+    def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+        
+        self.hue_shift_limit = hue_shift_limit
+        self.sat_shift_limit = sat_shift_limit
+        self.val_shift_limit = val_shift_limit
+        self.p = p
+
+    def __call__(self, sample):
+    
+        img = sample #,lbl
+    
+        if np.random.random() < self.p:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+            h, s, v = cv2.split(img)
+            hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+            hue_shift = np.uint8(hue_shift)
+            h += hue_shift
+            sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+            s = cv2.add(s, sat_shift)
+            val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+            v = cv2.add(v, val_shift)
+            img = cv2.merge((h, s, v))
+            img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+        return img #, lbl
+
+def to_fixed_size_bag(bag, bag_size: int = 512):
+
+    # duplicate bag instances until the bag reaches bag_size
+
+    # get up to bag_size elements
+    bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+    bag_samples = bag[bag_idxs]
+    # bag_sample_names = [bag_names[i] for i in bag_idxs]
+    q, r  = divmod(bag_size, bag_samples.shape[0])
+    if q > 0:
+        bag_samples = torch.cat([bag_samples]*q, 0)
+    
+    self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+
+    # zero-pad if we don't have enough samples
+    # zero_padded = torch.cat((bag_samples,
+                            # torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+
+    return self_padded, min(bag_size, len(bag))
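+# e.g. a 7-tile bag with bag_size=16 is repeated twice (14 tiles) and topped up with
+# its first 2 instances; the true instance count (7) is returned alongside the bag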
+
+
+if __name__ == '__main__':
+    from pathlib import Path
+    import os
+
+    home = Path.cwd().parts[1]
+    train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+    data_root = f'/{home}/ylan/data/DeepGraft/256_256um'
+    # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
+    os.makedirs(output_path, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
+    # print(dataset.dataset)
+    # a = int(len(dataset)* 0.8)
+    # b = int(len(dataset) - a)
+    # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    dl = DataLoader(dataset,  None, num_workers=1)
+    print(len(dl))
+    dl = DataLoader(dataset,  None, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
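+    # ImbalancedDatasetSampler reweights sampling toward rare classes; it relies on
+    # the dataset exposing its labels, which is what get_labels() above is for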
+
+    
+    
+    # data = DataLoader(dataset, batch_size=1)
+
+    # print(len(dataset))
+    # # x = 0
+    #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
+    c = 0
+    label_count = [0] *n_classes
+    print(len(dl))
+    for item in dl: 
+        # if c >=10:
+        #     break
+        bag, label, name = item
+        # print(label)
+        label_count[torch.argmax(label)] += 1
+        # print(name)
+        # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
+        
+            # print(bag)
+            # print(label)
+        c += 1
+    print(label_count)
+    #     # # print(bag.shape)
+    #     # if bag.shape[1] == 1:
+    #     #     print(name)
+    #     #     print(bag.shape)
+        # print(bag.shape)
+        
+    #     # out_dir = Path(output_path) / name
+    #     # os.makedirs(out_dir, exist_ok=True)
+
+    #     # # print(item[2])
+    #     # # print(len(item))
+    #     # # print(item[1])
+    #     # # print(data.shape)
+    #     # # data = data.squeeze()
+    #     # bag = item[0]
+    #     bag = bag.squeeze()
+    #     original = original.squeeze()
+    #     for i in range(bag.shape[0]):
+    #         img = bag[i, :, :, :]
+    #         img = img.squeeze()
+            
+    #         img = ((img-img.min())/(img.max() - img.min())) * 255
+    #         print(img)
+    #         # print(img)
+    #         img = img.numpy().astype(np.uint8).transpose(1,2,0)
+
+            
+    #         img = Image.fromarray(img)
+    #         img = img.convert('RGB')
+    #         img.save(f'{output_path}/{i}.png')
+
+
+            
+    #         o_img = original[i,:,:,:]
+    #         o_img = o_img.squeeze()
+    #         print(o_img.shape)
+    #         o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255
+    #         o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0)
+    #         o_img = Image.fromarray(o_img)
+    #         o_img = o_img.convert('RGB')
+    #         o_img.save(f'{output_path}/{i}_original.png')
+        
+    #     break
+        # else: break
+        # print(data.shape)
+        # print(label)
+    # a = [torch.Tensor((3,256,256))]*3
+    # b = torch.stack(a)
+    # print(b)
+    # c = to_fixed_size_bag(b, 512)
+    # print(c)
\ No newline at end of file
diff --git a/datasets/data_interface.py b/datasets/data_interface.py
index 056e6ff..efa104c 100644
--- a/datasets/data_interface.py
+++ b/datasets/data_interface.py
@@ -2,15 +2,25 @@ import inspect # inspect the parameters and source of Python classes, modules and functions
 import importlib # In order to dynamically import the library
 from typing import Optional
 import pytorch_lightning as pl
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+
 from torch.utils.data import random_split, DataLoader
+from torch.utils.data.dataset import Dataset, Subset
 from torchvision.datasets import MNIST
 from torchvision import transforms
 from .camel_dataloader import FeatureBagLoader
 from .custom_dataloader import HDF5MILDataloader
+from .custom_jpg_dataloader import JPGMILDataloader
 from pathlib import Path
 from transformers import AutoFeatureExtractor
 from torchsampler import ImbalancedDatasetSampler
 
+from abc import ABC, abstractmethod
+from sklearn.model_selection import KFold
+
+
+
 class DataInterface(pl.LightningDataModule):
 
     def __init__(self, train_batch_size=64, train_num_workers=8, test_batch_size=1, test_num_workers=1,dataset_name=None, **kwargs):
@@ -109,7 +119,7 @@ class DataInterface(pl.LightningDataModule):
 
 class MILDataModule(pl.LightningDataModule):
 
-    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=50, n_classes=2, cache: bool=True, *args, **kwargs):
         super().__init__()
         self.data_root = data_root
         self.label_path = label_path
@@ -124,27 +134,29 @@ class MILDataModule(pl.LightningDataModule):
         self.num_bags_test = 50
         self.seed = 1
 
-        self.backbone = backbone
         self.cache = True
         self.fe_transform = None
+        
 
 
     def setup(self, stage: Optional[str] = None) -> None:
         home = Path.cwd().parts[1]
 
         if stage in (None, 'fit'):
-            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
+            dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
             self.train_data, self.valid_data = random_split(dataset, [a, b])
 
         if stage in (None, 'test'):
-            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+            self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1)
 
         return super().setup(stage=stage)
 
+        
+
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.train_data,  self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        return DataLoader(self.train_data,  batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         #sampler=ImbalancedDatasetSampler(self.train_data)
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
@@ -187,13 +199,92 @@ class DataModule(pl.LightningDataModule):
         if stage in (None, 'test'):
             self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
 
+
         return super().setup(stage=stage)
 
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.train_data,  self.batch_size,  num_workers=self.num_workers, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        return DataLoader(self.train_data,  self.batch_size, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, 
         #sampler=ImbalancedDatasetSampler(self.train_data),
     def val_dataloader(self) -> DataLoader:
-        return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.valid_data, batch_size = self.batch_size)
+    
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_data, batch_size = self.batch_size) #, num_workers=self.num_workers
+
+
+class BaseKFoldDataModule(pl.LightningDataModule, ABC):
+    @abstractmethod
+    def setup_folds(self, num_folds: int) -> None:
+        pass
+
+    @abstractmethod
+    def setup_fold_index(self, fold_index: int) -> None:
+        pass
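+
+    # Intended protocol (a sketch, mirroring Lightning's K-fold loop example):
+    # a custom fit loop calls setup_folds(num_folds) once, then setup_fold_index(i)
+    # before each fold i so the dataloaders serve that fold's train/val subsets.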
+
+class CrossVal_MILDataModule(BaseKFoldDataModule):
+
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+        super().__init__()
+        self.data_root = data_root
+        self.label_path = label_path
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.image_size = 384
+        self.n_classes = n_classes
+        self.target_number = 9
+        self.mean_bag_length = 10
+        self.var_bag_length = 2
+        self.num_bags_train = 200
+        self.num_bags_test = 50
+        self.seed = 1
+
+        self.backbone = backbone
+        self.cache = True
+        self.fe_transform = None
+
+        # train_dataset: Optional[Dataset] = None
+        # test_dataset: Optional[Dataset] = None
+        # train_fold: Optional[Dataset] = None
+        # val_fold: Optional[Dataset] = None
+        self.train_data : Optional[Dataset] = None
+        self.test_data : Optional[Dataset] = None
+        self.train_fold : Optional[Dataset] = None
+        self.val_fold : Optional[Dataset] = None
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        home = Path.cwd().parts[1]
+
+        # if stage in (None, 'fit'):
+        dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
+        # a = int(len(dataset)* 0.8)
+        # b = int(len(dataset) - a)
+        # self.train_data, self.val_data = random_split(dataset, [a, b])
+        self.train_data = dataset
+
+        # if stage in (None, 'test'):,
+        self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes)
+
+        # return super().setup(stage=stage)
+
+    def setup_folds(self, num_folds: int) -> None:
+        self.num_folds = num_folds
+        self.splits = [split for split in KFold(num_folds).split(range(len(self.train_data)))]
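+        # NOTE: sklearn's KFold defaults to shuffle=False, so folds follow slide order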
+
+    def setup_fold_index(self, fold_index: int) -> None:
+        train_indices, val_indices = self.splits[fold_index]
+        self.train_fold = Subset(self.train_data, train_indices)
+        self.val_fold = Subset(self.train_data, val_indices)
+
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_fold,  self.batch_size, sampler=ImbalancedDatasetSampler(self.train_fold), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        # return DataLoader(self.train_fold,  self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        #sampler=ImbalancedDatasetSampler(self.train_data)
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.val_fold, batch_size = self.batch_size, num_workers=self.num_workers)
     
     def test_dataloader(self) -> DataLoader:
-        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
\ No newline at end of file
+        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
+
+
+
diff --git a/models/AttMIL.py b/models/AttMIL.py
new file mode 100644
index 0000000..d1e20eb
--- /dev/null
+++ b/models/AttMIL.py
@@ -0,0 +1,79 @@
+import os
+import logging
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+import pytorch_lightning as pl
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+
+
+class AttMIL(nn.Module): #gated attention
+    def __init__(self, n_classes, features=512):
+        super(AttMIL, self).__init__()
+        self.L = features
+        self.D = 128
+        self.K = 1
+        self.n_classes = n_classes
+
+        # resnet50 = models.resnet50(pretrained=True)    
+        # modules = list(resnet50.children())[:-3]
+
+        # self.resnet_extractor = nn.Sequential(
+        #     *modules,
+        #     nn.AdaptiveAvgPool2d(1),
+        #     View((-1, 1024)),
+        #     nn.Linear(1024, self.L)
+        # )
+
+        # self.feature_extractor1 = nn.Sequential(
+        #     nn.Conv2d(3, 20, kernel_size=5),
+        #     nn.ReLU(),
+        #     nn.MaxPool2d(2, stride=2),
+        #     nn.Conv2d(20, 50, kernel_size=5),
+        #     nn.ReLU(),
+        #     nn.MaxPool2d(2, stride=2),
+
+        #     # View((-1, 50 * 4 * 4)),
+        #     # nn.Linear(50 * 4 * 4, self.L),
+        #     # nn.ReLU(),
+        # )
+
+        # self.feature_extractor_part2 = nn.Sequential(
+        #     nn.Linear(50 * 4 * 4, self.L),
+        #     nn.ReLU(),
+        # )
+
+        self.attention_V = nn.Sequential(
+            nn.Linear(self.L, self.D),
+            nn.Tanh()
+        )
+
+        self.attention_U = nn.Sequential(
+            nn.Linear(self.L, self.D),
+            nn.Sigmoid()
+        )
+
+        self.attention_weights = nn.Linear(self.D, self.K)
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.L * self.K, self.n_classes),
+        )    
+
+    def forward(self, x):
+        # H = kwargs['data'].float().squeeze(0)
+        H = x.float().squeeze(0)
+        A_V = self.attention_V(H)  # NxD
+        A_U = self.attention_U(H)  # NxD
+        A = self.attention_weights(A_V * A_U) # element wise multiplication # NxK
+        out_A = A
+        A = torch.transpose(A, 1, 0)  # KxN
+        A = F.softmax(A, dim=1)  # softmax over N
+        M = torch.mm(A, H)  # KxL
+        logits = self.classifier(M)
+       
+        return logits
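+
+# Minimal smoke test (a sketch; assumes bags of 512-d tile features):
+# model = AttMIL(n_classes=2, features=512)
+# logits = model(torch.randn(1, 100, 512))  # -> shape [1, 2]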
\ No newline at end of file
diff --git a/models/TransMIL.py b/models/TransMIL.py
index 69089de..ca2a1fb 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -44,23 +44,23 @@ class PPEG(nn.Module):
 
 
 class TransMIL(nn.Module):
-    def __init__(self, n_classes):
+    def __init__(self, n_classes, in_features, out_features=384):
         super(TransMIL, self).__init__()
-        self.pos_layer = PPEG(dim=512)
-        self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
+        self.pos_layer = PPEG(dim=out_features)
+        self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
         # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
-        self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
         self.n_classes = n_classes
-        self.layer1 = TransLayer(dim=512)
-        self.layer2 = TransLayer(dim=512)
-        self.norm = nn.LayerNorm(512)
-        self._fc2 = nn.Linear(512, self.n_classes)
+        self.layer1 = TransLayer(dim=out_features)
+        self.layer2 = TransLayer(dim=out_features)
+        self.norm = nn.LayerNorm(out_features)
+        self._fc2 = nn.Linear(out_features, self.n_classes)
 
 
-    def forward(self, **kwargs): #, **kwargs
+    def forward(self, x): #, **kwargs
 
-        h = kwargs['data'].float() #[B, n, 1024]
-        # h = self._fc1(h) #[B, n, 512]
+        h = x.float() #[B, n, 1024]
+        h = self._fc1(h) #[B, n, 512]
         
         #---->pad
         H = h.shape[1]
@@ -83,22 +83,22 @@ class TransMIL(nn.Module):
         h = self.layer2(h) #[B, N, 512]
 
         #---->cls_token
+        # print(h.shape) #[1, 1025, 512] 1025 = cls_token + 1024
+
+        # tokens = h
         h = self.norm(h)[:,0]
 
         #---->predict
         logits = self._fc2(h) #[B, n_classes]
-        Y_hat = torch.argmax(logits, dim=1)
-        Y_prob = F.softmax(logits, dim = 1)
-        results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
-        return results_dict
+        return logits
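+        # (softmax/argmax moved out of forward; see ModelInterface.step)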
 
 if __name__ == "__main__":
     data = torch.randn((1, 6000, 512)).cuda()
-    model = TransMIL(n_classes=2).cuda()
+    model = TransMIL(n_classes=2, in_features=512).cuda()
     print(model.eval())
-    results_dict = model(data = data)
+    results_dict = model(data)
     print(results_dict)
-    logits = results_dict['logits']
-    Y_prob = results_dict['Y_prob']
-    Y_hat = results_dict['Y_hat']
+    # logits = results_dict['logits']
+    # Y_prob = results_dict['Y_prob']
+    # Y_hat = results_dict['Y_hat']
     # print(F.sigmoid(logits))
diff --git a/models/__pycache__/AttMIL.cpython-39.pyc b/models/__pycache__/AttMIL.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ee3af4e0d05559e5bc7a4d8ed62481a3254c9f1
Binary files /dev/null and b/models/__pycache__/AttMIL.cpython-39.pyc differ

diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
index cb0eb6fd5085feda9604ec0f24420d85eb1d939d..a329e23346d376b150c7ba792b8cd3593a128d95 100644
Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and b/models/__pycache__/TransMIL.cpython-39.pyc differ

diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
index 0bf337675c76de4097e7f7723b99de0120f9594c..e9d22d7ddccbb2a3d2ab59b98e066f9ac5d42285 100644
Binary files a/models/__pycache__/model_interface.cpython-39.pyc and b/models/__pycache__/model_interface.cpython-39.pyc differ

diff --git a/models/model_interface.py b/models/model_interface.py
index 60b5cc7..b3a561e 100755
--- a/models/model_interface.py
+++ b/models/model_interface.py
@@ -7,6 +7,8 @@ import pandas as pd
 import seaborn as sns
 from pathlib import Path
 from matplotlib import pyplot as plt
+import cv2
+from PIL import Image
 
 #---->
 from MyOptimizer import create_optimizer
@@ -20,7 +22,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics
+from torchmetrics.functional import stat_scores
 from torch import optim as optim
+# from sklearn.metrics import roc_curve, auc, roc_auc_score
+
 
 #---->
 import pytorch_lightning as pl
@@ -29,6 +34,10 @@ from torchvision import models
 from torchvision.models import resnet
 from transformers import AutoFeatureExtractor, ViTModel
 
+from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
 from captum.attr import LayerGradCam
 
 class ModelInterface(pl.LightningModule):
@@ -41,11 +50,20 @@ class ModelInterface(pl.LightningModule):
         self.loss = create_loss(loss)
         # self.asl = AsymmetricLossSingleLabel()
         # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
-        
         # self.loss = 
+        # print(self.model)
+        
+        
+        # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
         self.optimizer = optimizer
         self.n_classes = model.n_classes
-        self.log_path = kargs['log']
+        # print(self.n_classes)
+        self.save_path = kargs['log']
+        if Path(self.save_path).parts[3] == 'tcmr':
+            temp = list(Path(self.save_path).parts)
+            # print(temp)
+            temp[3] = 'tcmr_viral'
+            self.save_path = '/'.join(temp)
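+        # NOTE: assumes a fixed log-path layout in which parts[3] is the task name;
+        # legacy 'tcmr' runs are redirected to the 'tcmr_viral' directory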
 
         #---->acc
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
@@ -53,6 +71,7 @@ class ModelInterface(pl.LightningModule):
         #---->Metrics
         if self.n_classes > 2: 
             self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
                                                                            average='micro'),
                                                      torchmetrics.CohenKappa(num_classes = self.n_classes),
@@ -67,6 +86,7 @@ class ModelInterface(pl.LightningModule):
                                                                             
         else : 
             self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
                                                                            average = 'micro'),
                                                      torchmetrics.CohenKappa(num_classes = 2),
@@ -76,6 +96,8 @@ class ModelInterface(pl.LightningModule):
                                                                          num_classes = 2),
                                                      torchmetrics.Precision(average = 'macro',
                                                                             num_classes = 2)])
+        self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes)
+        # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10)
         self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)                                                                    
         self.valid_metrics = metrics.clone(prefix = 'val_')
         self.test_metrics = metrics.clone(prefix = 'test_')
@@ -146,28 +168,40 @@ class ModelInterface(pl.LightningModule):
                 nn.Linear(1024, self.out_features),
                 nn.ReLU(),
             )
+        # print(self.model_ft[0].features[-1])
+        # print(self.model_ft)
+        if model.name == 'TransMIL':
+            target_layers = [self.model.layer2.norm] # 32x32
+            # target_layers = [self.model_ft[0].features[-1]] # 32x32
+            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
+            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+        else:
+            target_layers = [self.model.attention_weights]
+            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
+
+    def forward(self, x):
+        # bag of tiles [N, C, H, W] -> per-tile features [N, F] -> batched bag [1, N, F]
+        feats = self.model_ft(x).unsqueeze(0)
+        return self.model(feats)
+
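+    # step(): shared train/val forward pass; drops the leading batch dim from the
+    # bag, runs the feature extractor + MIL head via forward(), and derives soft
+    # (softmax) and hard (argmax) predictions from the logits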
+    def step(self, input):
+
+        input = input.squeeze(0).float()
+        logits = self(input) 
+
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim=1)
+
+        return logits, Y_prob, Y_hat
 
     def training_step(self, batch, batch_idx):
         #---->inference
         
-        data, label, _ = batch
+
+        input, label, _= batch
         label = label.float()
-        data = data.squeeze(0).float()
-        # print(data)
-        # print(data.shape)
-        if self.backbone == 'dino':
-            features = self.model_ft(**data)
-            features = features.last_hidden_state
-        else:
-            features = self.model_ft(data)
-        features = features.unsqueeze(0)
-        # print(features.shape)
-        # features = features.squeeze()
-        results_dict = self.model(data=features) 
-        # results_dict = self.model(data=data, label=label)
-        logits = results_dict['logits']
-        Y_prob = results_dict['Y_prob']
-        Y_hat = results_dict['Y_hat']
+        
+        logits, Y_prob, Y_hat = self.step(input) 
 
         #---->loss
         loss = self.loss(logits, label)
@@ -183,6 +217,14 @@ class ModelInterface(pl.LightningModule):
             # Y = int(label[0])
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (int(Y_hat) == Y)
+        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True)
+
+        # log a grid of the input tiles every 10 epochs (assumes torchvision is imported)
+        if self.current_epoch % 10 == 0:
+            grid = torchvision.utils.make_grid(input.squeeze(0))
+            self.loggers[0].experiment.add_image('train/input', grid, self.current_epoch)
+
 
         return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
 
@@ -212,18 +254,10 @@ class ModelInterface(pl.LightningModule):
 
     def validation_step(self, batch, batch_idx):
 
-        data, label, _ = batch
-
+        input, label, _ = batch
         label = label.float()
-        data = data.squeeze(0).float()
-        features = self.model_ft(data)
-        features = features.unsqueeze(0)
-
-        results_dict = self.model(data=features)
-        logits = results_dict['logits']
-        Y_prob = results_dict['Y_prob']
-        Y_hat = results_dict['Y_hat']
-
+        
+        logits, Y_prob, Y_hat = self.step(input) 
 
         #---->acc log
         # Y = int(label[0][1])
@@ -237,18 +271,23 @@ class ModelInterface(pl.LightningModule):
 
     def validation_epoch_end(self, val_step_outputs):
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
-        # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
         target = torch.cat([x['label'] for x in val_step_outputs])
         target = torch.argmax(target, dim=1)
         #---->
         # logits = logits.long()
         # target = target.squeeze().long()
         # logits = logits.squeeze(0)
+        # AUROC is undefined when the validation targets contain a single class
+        if len(target.unique()) != 1:
+            self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+        else:
+            self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True)
+
         self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
-        self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
 
         # print(max_probs.squeeze(0).shape)
         # print(target.shape)
@@ -276,24 +315,61 @@ class ModelInterface(pl.LightningModule):
             random.seed(self.count*50)
 
     def test_step(self, batch, batch_idx):
-
-        data, label, _ = batch
+        torch.set_grad_enabled(True)  # Grad-CAM needs gradients even at test time
+        data, label, name = batch
         label = label.float()
         data = data.squeeze(0).float()
-        features = self.model_ft(data)
-        features = features.unsqueeze(0)
+        logits = self(data).detach()
+
+        Y = torch.argmax(label)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+        
+        #----> Get Topk tiles 
+
+        target = [ClassifierOutputTarget(Y)]
+
+        data_ft = self.model_ft(data).unsqueeze(0).float()
+        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
 
-        results_dict = self.model(data=features, label=label)
-        logits = results_dict['logits']
-        Y_prob = results_dict['Y_prob']
-        Y_hat = results_dict['Y_hat']
+        # average the Grad-CAM map over the feature axis to score each tile
+        summed = torch.mean(torch.Tensor(grayscale_cam), dim=2)
+        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
+        topk_data = data[topk_indices].detach()
+
+        # per-tile Grad-CAM overlays (show_cam_on_image) were removed here;
+        # the top-k tiles are written to disk in test_epoch_end instead
 
         #---->acc log
         Y = torch.argmax(label)
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name, 'topk_data': topk_data}
 
     def test_epoch_end(self, output_results):
         probs = torch.cat([x['Y_prob'] for x in output_results])
@@ -301,7 +377,8 @@ class ModelInterface(pl.LightningModule):
         # target = torch.stack([x['label'] for x in output_results], dim = 0)
         target = torch.cat([x['label'] for x in output_results])
         target = torch.argmax(target, dim=1)
-        
+        patients = [x['name'] for x in output_results]
+        topk_tiles = [x['topk_data'] for x in output_results]
         #---->
         auc = self.AUROC(probs, target.squeeze())
         metrics = self.test_metrics(max_probs.squeeze() , target)
@@ -312,9 +389,41 @@ class ModelInterface(pl.LightningModule):
 
         # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
 
-        # print(max_probs.squeeze(0).shape)
-        # print(target.shape)
-        # self.log_dict(metrics, logger = True)
+        #---->get highest scoring patients for each class
+        test_path = Path(self.save_path) / 'most_predictive'
+        topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
+        for n in range(self.n_classes):
+            print('class: ', n)
+            topk_patients = [patients[i[n]] for i in topk_indices]
+            topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
+            for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
+                print(p, x[n])
+                patient = p[0]
+                outpath = test_path / str(n) / patient
+                outpath.mkdir(parents=True, exist_ok=True)
+                for i in range(len(t)):
+                    # normalize each top tile to uint8 and write it out as a jpg
+                    tile = t[i].cpu().numpy().transpose(1,2,0)
+                    tile = (tile - tile.min()) / (tile.max() - tile.min()) * 255
+                    tile = tile.astype(np.uint8)
+                    img = Image.fromarray(tile)
+                    img.save(outpath / f'{i}_gradcam.jpg')
+
+        #----->visualize top predictive tiles
+        # overlaying the Grad-CAM heatmap on the tiles with show_cam_on_image
+        # is left for later
         for keys, values in metrics.items():
             print(f'{keys} = {values}')
             metrics[keys] = values.cpu().numpy()
@@ -329,16 +438,35 @@ class ModelInterface(pl.LightningModule):
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
+        #---->plot auroc curve: per-class ROC plotting is not implemented yet
+
         self.log_confusion_matrix(probs, target, stage='test')
         #---->
         result = pd.DataFrame([metrics])
-        result.to_csv(self.log_path / 'result.csv')
+        # append to the csv, writing the header only when the file is first created
+        result_csv = Path(self.save_path) / 'test_result.csv'
+        result.to_csv(result_csv, mode='a', header=not result_csv.exists())
 
     def configure_optimizers(self):
         # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
         optimizer = create_optimizer(self.optimizer, self.model)
         return optimizer     
 
+    def reshape_transform(self, tensor, h=32, w=32):
+        # drop the class token and fold the patch tokens back into a (B, C, h, w) map for Grad-CAM
+        result = tensor[:, 1:, :].reshape(tensor.size(0), h, w, tensor.size(2))
+        result = result.transpose(2,3).transpose(1,2)
+        return result
 
     def load_model(self):
         name = self.hparams.model.name
@@ -372,18 +500,33 @@ class ModelInterface(pl.LightningModule):
         args1.update(other_args)
         return Model(**args1)
 
+    def log_image(self, tensor, stage, name):
+        # normalize a (C, H, W) tile to uint8 and log it to TensorBoard
+        tile = tensor.cpu().numpy().transpose(1, 2, 0)
+        tile = (tile - tile.min()) / (tile.max() - tile.min()) * 255
+        tile = tile.astype(np.uint8)
+        self.loggers[0].experiment.add_image(f'{stage}/{name}', tile, self.current_epoch, dataformats='HWC')
+
     def log_confusion_matrix(self, max_probs, target, stage):
         confmat = self.confusion_matrix(max_probs.squeeze(), target)
         df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
         plt.figure()
         fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        # plt.close(fig_)
-        # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
-        self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
 
-        if stage == 'test':
-            plt.savefig(f'{self.log_path}/cm_test')
-        plt.close(fig_)
+        # log to TensorBoard during training; save to disk at test time
+        if stage == 'train':
+            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+        else:
+            fig_.savefig(f'{self.save_path}/cm_test.png', dpi=400)
+        plt.close(fig_)
         # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
 
 class View(nn.Module):
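
For reference, the Grad-CAM hook added to ModelInterface above hinges on reshape_transform turning TransMIL's token sequence back into a spatial map. A minimal stand-alone sketch of that reshape, assuming a class-token-first layout of 1 + h*w tokens (which is what the hunk above implies):

import torch

def reshape_transform(tensor, h=32, w=32):
    # drop the class token, then fold the remaining h*w tokens into a 2D grid
    result = tensor[:, 1:, :].reshape(tensor.size(0), h, w, tensor.size(2))
    # (B, h, w, C) -> (B, C, h, w), the layout Grad-CAM expects
    return result.transpose(2, 3).transpose(1, 2)

tokens = torch.randn(1, 1 + 32 * 32, 512)  # 1 class token + 1024 patch tokens
print(reshape_transform(tokens).shape)     # torch.Size([1, 512, 32, 32])
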
diff --git a/test_visualize.py b/test_visualize.py
new file mode 100644
index 0000000..7ef56d3
--- /dev/null
+++ b/test_visualize.py
@@ -0,0 +1,148 @@
+import argparse
+from pathlib import Path
+import numpy as np
+
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
+from models.model_interface import ModelInterface
+import models.vision_transformer as vits
+from utils.utils import *
+
+# pytorch_lightning
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+import torch
+
+#--->Setting parameters
+def make_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--stage', default='train', type=str)
+    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    parser.add_argument('--version', default=0,type=int)
+    parser.add_argument('--epoch', default='0',type=str)
+    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
+    parser.add_argument('--fold', default = 0)
+    parser.add_argument('--bag_size', default = 1024, type=int)
+
+    args = parser.parse_args()
+    return args
+
+#---->main
+def main(cfg):
+
+    torch.set_num_threads(16)
+
+    #---->Initialize seed
+    pl.seed_everything(cfg.General.seed)
+
+    #---->loggers and callbacks are not needed for test-time visualization
+
+    home = Path.cwd().parts[1]
+    DataInterface_dict = {
+                'data_root': cfg.Data.data_dir,
+                'label_path': cfg.Data.label_file,
+                'batch_size': cfg.Data.train_dataloader.batch_size,
+                'num_workers': cfg.Data.train_dataloader.num_workers,
+                'n_classes': cfg.Model.n_classes,
+                'backbone': cfg.Model.backbone,
+                'bag_size': cfg.Data.bag_size,
+                }
+
+    dm = MILDataModule(**DataInterface_dict)
+    
+
+    #---->Define Model
+    ModelInterface_dict = {'model': cfg.Model,
+                            'loss': cfg.Loss,
+                            'optimizer': cfg.Optimizer,
+                            'data': cfg.Data,
+                            'log': cfg.log_path,
+                            'backbone': cfg.Model.backbone,
+                            }
+    model = ModelInterface(**ModelInterface_dict)
+    
+    #---->Instantiate Trainer
+    trainer = Trainer(
+        num_sanity_val_steps=0,
+        max_epochs=cfg.General.epochs,
+        min_epochs=200,
+        gpus=cfg.General.gpus,
+        amp_backend='native',
+        precision=cfg.General.precision,
+        accumulate_grad_batches=cfg.General.grad_acc,
+        check_val_every_n_epoch=10,
+    )
+
+    #---->collect checkpoints to test
+    log_path = cfg.log_path
+    model_paths = list(log_path.glob('*.ckpt'))
+
+    if cfg.epoch == 'last':
+        model_paths = [str(model_path) for model_path in model_paths if 'last' in str(model_path)]
+    else:
+        model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+
+    if not model_paths: 
+        print('No checkpoints available!')
+    for path in model_paths:
+        print(path)
+        new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
+        trainer.test(model=new_model, datamodule=dm)
+    
+    # Top 5 scoring patches for patient
+    # GradCam
+
+
+if __name__ == '__main__':
+
+    args = make_parse()
+    cfg = read_yaml(args.config)
+
+    #---->update
+    cfg.config = args.config
+    cfg.General.gpus = [args.gpus]
+    cfg.General.server = args.stage
+    cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
+    cfg.version = args.version
+    cfg.epoch = args.epoch
+
+    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+    log_name = f'_{cfg.Model.backbone}_{cfg.Loss.base_loss}'
+    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' / 'checkpoints'
+
+    #---->main
+    main(cfg)
+ 
\ No newline at end of file
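
The checkpoint selection in test_visualize.py above reduces to a small filter over the checkpoint directory; a sketch, with the directory path invented for illustration:

from pathlib import Path

def select_checkpoints(ckpt_dir: Path, epoch: str):
    # pick 'last.ckpt' when epoch == 'last', otherwise the checkpoint of a given epoch
    paths = [str(p) for p in ckpt_dir.glob('*.ckpt')]
    if epoch == 'last':
        return [p for p in paths if 'last' in p]
    return [p for p in paths if f'epoch={epoch}' in p]

# e.g. select_checkpoints(Path('logs/DeepGraft/TransMIL/debug/checkpoints'), '10')
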
diff --git a/train.py b/train.py
index 036d5ed..5e30394 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ import glob
 
 from sklearn.model_selection import KFold
 
-from datasets.data_interface import DataInterface, MILDataModule
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
 from models.model_interface import ModelInterface
 import models.vision_transformer as vits
 from utils.utils import *
@@ -13,18 +13,23 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
-from pytorch_lightning.loops import KFoldLoop
 import torch
+from train_loop import KFoldLoop
 
 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    parser.add_argument('--version', default=2,type=int)
     parser.add_argument('--gpus', default = 2, type=int)
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
     parser.add_argument('--bag_size', default = 1024, type=int)
+    parser.add_argument('--resume_training', action='store_true')
+
     args = parser.parse_args()
     return args
 
@@ -39,9 +44,10 @@ def main(cfg):
     #---->load loggers
     cfg.load_loggers = load_loggers(cfg)
     # print(cfg.load_loggers)
+    save_path = Path(cfg.load_loggers[0].log_dir) 
 
     #---->load callbacks
-    cfg.callbacks = load_callbacks(cfg)
+    cfg.callbacks = load_callbacks(cfg, save_path)
 
     #---->Define Data 
     # DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
@@ -58,11 +64,12 @@ def main(cfg):
                 'batch_size': cfg.Data.train_dataloader.batch_size,
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
-                'backbone': cfg.Model.backbone,
                 'bag_size': cfg.Data.bag_size,
                 }
 
-    dm = MILDataModule(**DataInterface_dict)
+    if cfg.Data.cross_val:
+        dm = CrossVal_MILDataModule(**DataInterface_dict)
+    else:
+        dm = MILDataModule(**DataInterface_dict)
     
 
     #---->Define Model
@@ -82,9 +89,9 @@ def main(cfg):
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        # gpus=cfg.General.gpus,
-        gpus = [2,3],
-        strategy='ddp',
+        gpus=cfg.General.gpus,
+        # gpus = [0,2],
+        # strategy='ddp',
         amp_backend='native',
         # amp_level=cfg.General.amp_level,  
         precision=cfg.General.precision,  
@@ -96,12 +103,31 @@ def main(cfg):
     )
 
     #---->train or test
+    if cfg.resume_training:
+        last_ckpt = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' / 'last.ckpt'
+        trainer.fit(model = model, datamodule = dm, ckpt_path=last_ckpt)
+        return  # resuming replaces the fresh fit below
+
     if cfg.General.server == 'train':
-        trainer.fit_loop = KFoldLoop(3, trainer.fit_loop, )
-        trainer.fit(model = model, datamodule = dm)
+
+        # k-fold cross validation loop
+        if cfg.Data.cross_val: 
+            internal_fit_loop = trainer.fit_loop
+            trainer.fit_loop = KFoldLoop(cfg.Data.nfold, export_path = cfg.log_path, **ModelInterface_dict)
+            trainer.fit_loop.connect(internal_fit_loop)
+            trainer.fit(model, dm)
+        else:
+            trainer.fit(model = model, datamodule = dm)
     else:
-        model_paths = list(cfg.log_path.glob('*.ckpt'))
+        log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' 
+
+        test_path = Path(log_path) / 'test'
+        for n in range(cfg.Model.n_classes):
+            n_output_path = test_path / str(n)
+            n_output_path.mkdir(parents=True, exist_ok=True)
+        # print(cfg.log_path)
+        model_paths = list(log_path.glob('*.ckpt'))
         model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
+        # model_paths = [f'{log_path}/epoch=279-val_loss=0.4009.ckpt']
         for path in model_paths:
             print(path)
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
@@ -120,6 +146,16 @@ if __name__ == '__main__':
     cfg.Data.fold = args.fold
     cfg.Loss.base_loss = args.loss
     cfg.Data.bag_size = args.bag_size
+    cfg.version = args.version
+
+    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+    log_name = f'_{cfg.Model.backbone}_{cfg.Loss.base_loss}'
+    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name
+
     
 
     #---->main
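
A worked example of the log-path construction that train.py and test_visualize.py now share; the backbone and loss names are stand-ins:

from pathlib import Path

config = 'DeepGraft/TransMIL_resnet18_debug.yaml'
log_path = Path('logs/') / str(Path(config).parent)     # logs/DeepGraft
log_name = '_resnet18_CrossEntropyLoss'                 # f'_{backbone}_{loss}'
task = '_'.join(Path(config).name[:-5].split('_')[2:])  # 'debug'
print(log_path / 'TransMIL' / task / log_name)
# logs/DeepGraft/TransMIL/debug/_resnet18_CrossEntropyLoss
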
diff --git a/train_loop.py b/train_loop.py
new file mode 100644
index 0000000..f923681
--- /dev/null
+++ b/train_loop.py
@@ -0,0 +1,212 @@
+import os.path as osp
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Type
+
+import torch
+import torch.nn.functional as F
+import torchmetrics
+from torchmetrics.classification.accuracy import Accuracy
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from pytorch_lightning import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
+from datasets.data_interface import BaseKFoldDataModule
+
+
+class EnsembleVotingModel(LightningModule):
+    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str], n_classes, log_path) -> None:
+        super().__init__()
+        # Create `num_folds` models with their associated fold weights
+        self.n_classes = n_classes
+        self.log_path = log_path
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
+        self.test_acc = Accuracy()
+        if self.n_classes > 2: 
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
+                                                                           average='micro'),
+                                                     torchmetrics.CohenKappa(num_classes = self.n_classes),
+                                                     torchmetrics.F1Score(num_classes = self.n_classes,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = self.n_classes),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = self.n_classes),
+                                                     torchmetrics.Specificity(average = 'macro',
+                                                                            num_classes = self.n_classes)])
+                                                                            
+        else : 
+            self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
+                                                                           average = 'micro'),
+                                                     torchmetrics.CohenKappa(num_classes = 2),
+                                                     torchmetrics.F1Score(num_classes = 2,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = 2),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = 2)])
+        self.test_metrics = metrics.clone(prefix = 'test_')
+        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
+
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        # Compute the averaged predictions over the `num_folds` models.
+        # print(batch[0].shape)
+        input, label, _ = batch
+        label = label.float()
+        input = input.squeeze(0).float()
+
+        # average the logits of the per-fold models
+        logits = torch.stack([m(input) for m in self.models]).mean(0)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+        # #---->acc log
+        Y = torch.argmax(label)
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+
+    def test_epoch_end(self, output_results):
+        probs = torch.cat([x['Y_prob'] for x in output_results])
+        max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.cat([x['label'] for x in output_results])
+        target = torch.argmax(target, dim=1)
+        
+        #---->
+        auc = self.AUROC(probs, target.squeeze())
+        metrics = self.test_metrics(max_probs.squeeze(), target)
+        metrics['test_auc'] = auc
+
+        for keys, values in metrics.items():
+            print(f'{keys} = {values}')
+            metrics[keys] = values.cpu().numpy()
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+        self.log_confusion_matrix(probs, target, stage='test')
+        #---->
+        result = pd.DataFrame([metrics])
+        result.to_csv(self.log_path / 'result.csv')
+
+
+    def log_confusion_matrix(self, max_probs, target, stage):
+        confmat = self.confusion_matrix(max_probs.squeeze(), target)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        plt.figure()
+        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+        self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+        if stage == 'test':
+            plt.savefig(f'{self.log_path}/cm_test')
+        plt.close(fig_)
+
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, export_path: str, **kwargs) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+        self.n_classes = kwargs["model"].n_classes
+        self.log_path = kwargs["log"]
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to the run a fitting and testing on the current hold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths, n_classes=self.n_classes, log_path=self.log_path)
+        voting_model.trainer = self.trainer
+        # This requires connecting the new model and moving it to the right device.
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # requires to be overridden as attributes of the wrapped loop are being accessed.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file
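
The EnsembleVotingModel above averages the per-fold logits before the softmax, so the ensemble votes in logit space rather than probability space; a minimal stand-alone sketch with dummy tensors:

import torch
import torch.nn.functional as F

fold_logits = [torch.randn(1, 2) for _ in range(3)]  # one logit tensor per fold model
logits = torch.stack(fold_logits).mean(0)            # averaged (1, 2) logits
Y_prob = F.softmax(logits, dim=1)
Y_hat = torch.argmax(logits, dim=1)
print(Y_prob, Y_hat)
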
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
index 704aade8ba3e8bcbedf237b3fde5db80b84fcd32..33d3d005a578948dd3d5b58938a815f86c4e92f5 100644
GIT binary patch
(binary delta for compiled .pyc file omitted)
diff --git a/utils/utils.py b/utils/utils.py
index 010eaab..814cf1d 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -10,10 +10,11 @@ from torch.utils.data.dataset import Dataset, Subset
 from torchmetrics.classification.accuracy import Accuracy
 
 from pytorch_lightning import LightningDataModule, seed_everything, Trainer
-from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning import LightningModule
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.trainer.states import TrainerFn
+from typing import Any, Dict, List, Optional, Type
 
 #---->read yaml
 import yaml
@@ -27,30 +28,30 @@ def read_yaml(fpath=None):
 from pytorch_lightning import loggers as pl_loggers
 def load_loggers(cfg):
 
-    log_path = cfg.General.log_path
-    Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
-    version_name = Path(cfg.config).name[:-5]
     
     
     #---->TensorBoard
     if cfg.stage != 'test':
-        cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
-        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                                name = version_name, version = f'fold{cfg.Data.fold}',
+        tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
                                                 log_graph = True, default_hp_metric = False)
         #---->CSV
-        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                        name = version_name, version = f'fold{cfg.Data.fold}', )
+        csv_logger = pl_loggers.CSVLogger(cfg.log_path)
     else:  
-        cfg.log_path = Path(log_path) / log_name / version_name / f'test'
-        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                                name = version_name, version = f'test',
+        cfg.log_path = Path(cfg.log_path) / 'test'
+        tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
+                                                version = 'test',
                                                 log_graph = True, default_hp_metric = False)
         #---->CSV
-        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                        name = version_name, version = f'test', )
-                                    
+        csv_logger = pl_loggers.CSVLogger(cfg.log_path, version = 'test')
+                              
     
     print(f'---->Log dir: {cfg.log_path}')
 
@@ -63,11 +64,11 @@ from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
-def load_callbacks(cfg):
+def load_callbacks(cfg, save_path):
 
     Mycallbacks = []
     # Make output path
-    output_path = cfg.log_path
+    output_path = save_path / 'checkpoints' 
     output_path.mkdir(exist_ok=True, parents=True)
 
     early_stop_callback = EarlyStopping(
@@ -94,8 +95,9 @@ def load_callbacks(cfg):
     Mycallbacks.append(progress_bar)
 
     if cfg.General.server == 'train' :
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
-                                         dirpath = str(cfg.log_path),
+                                         dirpath = str(output_path),
                                          filename = '{epoch:02d}-{val_loss:.4f}',
                                          verbose = True,
                                          save_last = True,
@@ -103,7 +105,7 @@ def load_callbacks(cfg):
                                          mode = 'min',
                                          save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
-                                         dirpath = str(cfg.log_path),
+                                         dirpath = str(output_path),
                                          filename = '{epoch:02d}-{val_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
@@ -136,87 +138,3 @@ def convert_labels_for_task(task, label):
     return label_map[task][label]
 
 
-#-----> KFOLD LOOP
-
-class KFoldLoop(Loop):
-    def __init__(self, num_folds: int, export_path: str) -> None:
-        super().__init__()
-        self.num_folds = num_folds
-        self.current_fold: int = 0
-        self.export_path = export_path
-
-    @property
-    def done(self) -> bool:
-        return self.current_fold >= self.num_folds
-
-    def connect(self, fit_loop: FitLoop) -> None:
-        self.fit_loop = fit_loop
-
-    def reset(self) -> None:
-        """Nothing to reset in this loop."""
-
-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
-        model."""
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_folds(self.num_folds)
-        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
-
-    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
-        print(f"STARTING FOLD {self.current_fold}")
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_fold_index(self.current_fold)
-
-    def advance(self, *args: Any, **kwargs: Any) -> None:
-        """Used to the run a fitting and testing on the current hold."""
-        self._reset_fitting()  # requires to reset the tracking stage.
-        self.fit_loop.run()
-
-        self._reset_testing()  # requires to reset the tracking stage.
-        self.trainer.test_loop.run()
-        self.current_fold += 1  # increment fold tracking number.
-
-    def on_advance_end(self) -> None:
-        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
-        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
-        # restore the original weights + optimizers and schedulers.
-        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
-        self.trainer.strategy.setup_optimizers(self.trainer)
-        self.replace(fit_loop=FitLoop)
-
-    def on_run_end(self) -> None:
-        """Used to compute the performance of the ensemble model on the test set."""
-        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
-        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
-        voting_model.trainer = self.trainer
-        # This requires to connect the new model and move it the right device.
-        self.trainer.strategy.connect(voting_model)
-        self.trainer.strategy.model_to_device()
-        self.trainer.test_loop.run()
-
-    def on_save_checkpoint(self) -> Dict[str, int]:
-        return {"current_fold": self.current_fold}
-
-    def on_load_checkpoint(self, state_dict: Dict) -> None:
-        self.current_fold = state_dict["current_fold"]
-
-    def _reset_fitting(self) -> None:
-        self.trainer.reset_train_dataloader()
-        self.trainer.reset_val_dataloader()
-        self.trainer.state.fn = TrainerFn.FITTING
-        self.trainer.training = True
-
-    def _reset_testing(self) -> None:
-        self.trainer.reset_test_dataloader()
-        self.trainer.state.fn = TrainerFn.TESTING
-        self.trainer.testing = True
-
-    def __getattr__(self, key) -> Any:
-        # requires to be overridden as attributes of the wrapped loop are being accessed.
-        if key not in self.__dict__:
-            return getattr(self.fit_loop, key)
-        return self.__dict__[key]
-
-    def __setstate__(self, state: Dict[str, Any]) -> None:
-        self.__dict__.update(state)
\ No newline at end of file
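
For context, the ModelCheckpoint callbacks configured in load_callbacks above produce the file names that train.py and test_visualize.py later glob for ('epoch=...' and 'last.ckpt'). A sketch, with the dirpath invented for illustration:

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt_cb = ModelCheckpoint(
    monitor='val_loss',
    dirpath='logs/DeepGraft/example/checkpoints',  # hypothetical save location
    filename='{epoch:02d}-{val_loss:.4f}',         # saved as e.g. 'epoch=04-val_loss=0.6931.ckpt'
    save_last=True,                                # additionally writes 'last.ckpt'
    mode='min',
    save_weights_only=True,
)
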
-- 
GitLab