From d7fe105eceef23555df224fd257187c17ba64430 Mon Sep 17 00:00:00 2001
From: Ycblue <yuchialan@gmail.com>
Date: Thu, 23 Jun 2022 10:39:02 +0200
Subject: [PATCH] added jpg_dataloader, rm fixed_sized_bag
---
.gitignore | 4 +-
DeepGraft/AttMIL_resnet18_debug.yaml | 51 ++
DeepGraft/AttMIL_simple_no_other.yaml | 49 ++
DeepGraft/AttMIL_simple_no_viral.yaml | 48 ++
DeepGraft/AttMIL_simple_tcmr_viral.yaml | 49 ++
DeepGraft/TransMIL_efficientnet_no_other.yaml | 4 +-
DeepGraft/TransMIL_efficientnet_no_viral.yaml | 6 +-
.../TransMIL_efficientnet_tcmr_viral.yaml | 8 +-
...ebug.yaml => TransMIL_resnet18_debug.yaml} | 7 +-
DeepGraft/TransMIL_resnet50_tcmr_viral.yaml | 7 +-
__pycache__/train_loop.cpython-39.pyc | Bin 0 -> 9235 bytes
.../custom_dataloader.cpython-39.pyc | Bin 15301 -> 15363 bytes
.../custom_jpg_dataloader.cpython-39.pyc | Bin 0 -> 11037 bytes
.../__pycache__/data_interface.cpython-39.pyc | Bin 7011 -> 10339 bytes
datasets/custom_dataloader.py | 77 +--
datasets/custom_jpg_dataloader.py | 459 ++++++++++++++++++
datasets/data_interface.py | 107 +++-
models/AttMIL.py | 79 +++
models/TransMIL.py | 40 +-
models/__pycache__/AttMIL.cpython-39.pyc | Bin 0 -> 1551 bytes
models/__pycache__/TransMIL.cpython-39.pyc | Bin 3328 -> 3233 bytes
.../model_interface.cpython-39.pyc | Bin 11282 -> 14054 bytes
models/model_interface.py | 249 ++++++++--
test_visualize.py | 148 ++++++
train.py | 58 ++-
train_loop.py | 212 ++++++++
utils/__pycache__/utils.cpython-39.pyc | Bin 3450 -> 4005 bytes
utils/utils.py | 126 +----
28 files changed, 1547 insertions(+), 241 deletions(-)
create mode 100644 DeepGraft/AttMIL_resnet18_debug.yaml
create mode 100644 DeepGraft/AttMIL_simple_no_other.yaml
create mode 100644 DeepGraft/AttMIL_simple_no_viral.yaml
create mode 100644 DeepGraft/AttMIL_simple_tcmr_viral.yaml
rename DeepGraft/{TransMIL_debug.yaml => TransMIL_resnet18_debug.yaml} (91%)
create mode 100644 __pycache__/train_loop.cpython-39.pyc
create mode 100644 datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
create mode 100644 datasets/custom_jpg_dataloader.py
create mode 100644 models/AttMIL.py
create mode 100644 models/__pycache__/AttMIL.cpython-39.pyc
create mode 100644 test_visualize.py
create mode 100644 train_loop.py
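
The new DeepGraft/*.yaml files all share one schema (General / Data / Model /
Optimizer / Loss). A minimal sketch of how such a config might be consumed with
plain PyYAML (key access follows the configs below; `&epoch` is an ordinary
YAML anchor, so the value parses as a plain integer):

    import yaml

    with open('DeepGraft/AttMIL_simple_tcmr_viral.yaml') as f:
        cfg = yaml.safe_load(f)

    print(cfg['General']['epochs'])  # 300 -- '&epoch 300' is just an anchored int
    print(cfg['Model'])              # {'name': 'AttMIL', 'n_classes': 2, 'backbone': 'simple'}
    print(cfg['Data']['fold'], cfg['Data']['nfold'])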
diff --git a/.gitignore b/.gitignore
index 4cf8dd1..c9e4fda 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
-logs/*
\ No newline at end of file
+logs/*
+lightning_logs/*
+test/*
\ No newline at end of file
diff --git a/DeepGraft/AttMIL_resnet18_debug.yaml b/DeepGraft/AttMIL_resnet18_debug.yaml
new file mode 100644
index 0000000..03ebd7e
--- /dev/null
+++ b/DeepGraft/AttMIL_resnet18_debug.yaml
@@ -0,0 +1,51 @@
+General:
+ comment:
+ seed: 2021
+ fp16: True
+ amp_level: O2
+ precision: 16
+ multi_gpu_mode: dp
+ gpus: [1]
+ epochs: &epoch 1
+ grad_acc: 2
+ frozen_bn: False
+ patience: 2
+ server: test #train #test
+ log_path: logs/
+
+Data:
+ dataset_name: custom
+ data_shuffle: False
+ data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+ label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
+ fold: 0
+ nfold: 2
+ cross_val: False
+
+ train_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+ test_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+
+
+Model:
+ name: AttMIL
+ n_classes: 2
+ backbone: simple
+
+
+Optimizer:
+ opt: lookahead_radam
+ lr: 0.0002
+ opt_eps: null
+ opt_betas: null
+ momentum: null
+ weight_decay: 0.00001
+
+Loss:
+ base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_no_other.yaml b/DeepGraft/AttMIL_simple_no_other.yaml
new file mode 100644
index 0000000..ae90a80
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_no_other.yaml
@@ -0,0 +1,49 @@
+General:
+ comment:
+ seed: 2021
+ fp16: True
+ amp_level: O2
+ precision: 16
+ multi_gpu_mode: dp
+ gpus: [3]
+ epochs: &epoch 200
+ grad_acc: 2
+ frozen_bn: False
+ patience: 20
+ server: train #train #test
+ log_path: logs/
+
+Data:
+ dataset_name: custom
+ data_shuffle: False
+ data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+ label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+ fold: 1
+ nfold: 3
+ cross_val: False
+
+ train_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+ test_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+Model:
+ name: AttMIL
+ n_classes: 5
+ backbone: simple
+
+
+Optimizer:
+ opt: lookahead_radam
+ lr: 0.0002
+ opt_eps: null
+ opt_betas: null
+ momentum: null
+ weight_decay: 0.00001
+
+Loss:
+ base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_no_viral.yaml b/DeepGraft/AttMIL_simple_no_viral.yaml
new file mode 100644
index 0000000..37ee074
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+ comment:
+ seed: 2021
+ fp16: True
+ amp_level: O2
+ precision: 16
+ multi_gpu_mode: dp
+ gpus: [3]
+ epochs: &epoch 500
+ grad_acc: 2
+ frozen_bn: False
+ patience: 50
+ server: test #train #test
+ log_path: logs/
+
+Data:
+ dataset_name: custom
+ data_shuffle: False
+ data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+ label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+ fold: 1
+ nfold: 4
+
+ train_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+ test_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+Model:
+ name: AttMIL
+ n_classes: 4
+ backbone: simple
+
+
+Optimizer:
+ opt: lookahead_radam
+ lr: 0.0002
+ opt_eps: null
+ opt_betas: null
+ momentum: null
+ weight_decay: 0.00001
+
+Loss:
+ base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/AttMIL_simple_tcmr_viral.yaml b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
new file mode 100644
index 0000000..c982d3a
--- /dev/null
+++ b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
@@ -0,0 +1,49 @@
+General:
+ comment:
+ seed: 2021
+ fp16: True
+ amp_level: O2
+ precision: 16
+ multi_gpu_mode: dp
+ gpus: [3]
+ epochs: &epoch 300
+ grad_acc: 2
+ frozen_bn: False
+ patience: 20
+ server: train #train #test
+ log_path: logs/
+
+Data:
+ dataset_name: custom
+ data_shuffle: False
+ data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+ label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+ fold: 1
+ nfold: 3
+ cross_val: False
+
+ train_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+ test_dataloader:
+ batch_size: 1
+ num_workers: 8
+
+Model:
+ name: AttMIL
+ n_classes: 2
+ backbone: simple
+
+
+Optimizer:
+ opt: lookahead_radam
+ lr: 0.0002
+ opt_eps: null
+ opt_betas: null
+ momentum: null
+ weight_decay: 0.00001
+
+Loss:
+ base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml
index 79d8ea8..7687a0c 100644
--- a/DeepGraft/TransMIL_efficientnet_no_other.yaml
+++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml
@@ -6,10 +6,10 @@ General:
precision: 16
multi_gpu_mode: dp
gpus: [0]
- epochs: &epoch 1000
+ epochs: &epoch 200
grad_acc: 2
frozen_bn: False
- patience: 200
+ patience: 20
server: test #train #test
log_path: logs/
diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
index 8780060..98fe377 100644
--- a/DeepGraft/TransMIL_efficientnet_no_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
@@ -5,11 +5,11 @@ General:
amp_level: O2
precision: 16
multi_gpu_mode: dp
- gpus: [3]
- epochs: &epoch 500
+ gpus: [0, 2]
+ epochs: &epoch 200
grad_acc: 2
frozen_bn: False
- patience: 200
+ patience: 20
server: test #train #test
log_path: logs/
diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
index f69b5bf..5223032 100644
--- a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
@@ -5,7 +5,7 @@ General:
amp_level: O2
precision: 16
multi_gpu_mode: dp
- gpus: [3]
+ gpus: [0]
epochs: &epoch 500
grad_acc: 2
frozen_bn: False
@@ -16,10 +16,11 @@ General:
Data:
dataset_name: custom
data_shuffle: False
- data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+ data_dir: '/home/ylan/data/DeepGraft/256_256um/'
label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
fold: 1
- nfold: 4
+ nfold: 3
+ cross_val: True
train_dataloader:
batch_size: 1
@@ -33,6 +34,7 @@ Model:
name: TransMIL
n_classes: 2
backbone: efficientnet
+ in_features: 512
Optimizer:
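
The new `in_features: 512` key suggests the feature dimension fed to TransMIL is
now configurable per backbone. A hedged sketch of how the Model block could feed
model construction (this factory is an assumption for illustration, not the
repository's actual code in models/model_interface.py):

    model_cfg = cfg['Model']
    if model_cfg['name'] == 'TransMIL':
        # n_classes/in_features come from the YAML; the fallback value is an assumption
        model = TransMIL(n_classes=model_cfg['n_classes'],
                         in_features=model_cfg.get('in_features', 512))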
diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_resnet18_debug.yaml
similarity index 91%
rename from DeepGraft/TransMIL_debug.yaml
rename to DeepGraft/TransMIL_resnet18_debug.yaml
index d83ce0d..29bfa1e 100644
--- a/DeepGraft/TransMIL_debug.yaml
+++ b/DeepGraft/TransMIL_resnet18_debug.yaml
@@ -6,10 +6,10 @@ General:
precision: 16
multi_gpu_mode: dp
gpus: [1]
- epochs: &epoch 200
+ epochs: &epoch 1
grad_acc: 2
frozen_bn: False
- patience: 200
+ patience: 2
server: test #train #test
log_path: logs/
@@ -19,7 +19,8 @@ Data:
data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
fold: 0
- nfold: 4
+ nfold: 2
+ cross_val: True
train_dataloader:
batch_size: 1
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
index a756616..f6e4697 100644
--- a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -5,8 +5,8 @@ General:
amp_level: O2
precision: 16
multi_gpu_mode: dp
- gpus: [3]
- epochs: &epoch 500
+ gpus: [0]
+ epochs: &epoch 200
grad_acc: 2
frozen_bn: False
patience: 50
@@ -19,7 +19,8 @@ Data:
data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
fold: 1
- nfold: 4
+ nfold: 3
+ cross_val: True
train_dataloader:
batch_size: 1
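
Several configs here switch from `nfold: 4` to `nfold: 3` and enable
`cross_val`. A sketch of the split semantics these keys likely imply, assuming a
standard k-fold over slide IDs (the real split logic lives in
datasets/data_interface.py, which this patch also extends):

    from sklearn.model_selection import KFold

    kf = KFold(n_splits=cfg['Data']['nfold'], shuffle=True,
               random_state=cfg['General']['seed'])
    splits = list(kf.split(slide_ids))                # slide_ids: list of WSI names
    train_idx, val_idx = splits[cfg['Data']['fold']]  # 'fold' selects one split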
diff --git a/__pycache__/train_loop.cpython-39.pyc b/__pycache__/train_loop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65c5ca839312787cce6aed19f346ae40a7c04d83
Binary files /dev/null and b/__pycache__/train_loop.cpython-39.pyc differ
diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc
index 4a200bbb7d34328019d9bb9e604b0524127ae863..99e793aa4b84b2a5e9ee1882437346e6e4a33002 100644
Binary files a/datasets/__pycache__/custom_dataloader.cpython-39.pyc and b/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20aefffd1a14afe3c499f646e9eb674810896281
Binary files /dev/null and b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc
index 9550db1509e8d477a1c15534a76c6f87976fbec4..59584d31b16ee5ba7d08220a2174fb5e79e9aaed 100644
Binary files a/datasets/__pycache__/data_interface.cpython-39.pyc and b/datasets/__pycache__/data_interface.cpython-39.pyc differ
diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py
index 02850f5..cf65534 100644
--- a/datasets/custom_dataloader.py
+++ b/datasets/custom_dataloader.py
@@ -41,7 +41,7 @@ class HDF5MILDataloader(data.Dataset):
data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
"""
- def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024):
+ def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024):
super().__init__()
self.data_info = []
@@ -55,7 +55,6 @@ class HDF5MILDataloader(data.Dataset):
self.label_path = label_path
self.n_classes = n_classes
self.bag_size = bag_size
- self.backbone = backbone
# self.label_file = label_path
recursive = True
@@ -134,10 +133,6 @@ class HDF5MILDataloader(data.Dataset):
RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
transforms.ToTensor()
])
- if self.backbone == 'dino':
- self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
-
- # self._add_data_infos(load_data)
def __getitem__(self, index):
# get data
@@ -150,9 +145,6 @@ class HDF5MILDataloader(data.Dataset):
# print(img.shape)
for img in batch: # expects numpy
img = img.numpy().astype(np.uint8)
- if self.backbone == 'dino':
- img = self.feature_extractor(images=img, return_tensors='pt')
- # img = self.resize_transforms(img)
img = seq_img_d.augment_image(img)
img = self.val_transforms(img)
out_batch.append(img)
@@ -160,23 +152,24 @@ class HDF5MILDataloader(data.Dataset):
else:
for img in batch:
img = img.numpy().astype(np.uint8)
- if self.backbone == 'dino':
- img = self.feature_extractor(images=img, return_tensors='pt')
- img = self.resize_transforms(img)
-
img = self.val_transforms(img)
out_batch.append(img)
- if len(out_batch) == 0:
- # print(name)
- out_batch = torch.randn(self.bag_size,3,256,256)
- else: out_batch = torch.stack(out_batch)
+ # if len(out_batch) == 0:
+ # # print(name)
+ # out_batch = torch.randn(self.bag_size,3,256,256)
+ # else:
+ out_batch = torch.stack(out_batch)
# print(out_batch.shape)
# out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
-
+ # print(out_batch.shape)
+ if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]):
+ print(name)
+ print(out_batch.shape)
+ out_batch = torch.permute(out_batch, (0, 2,1,3))
label = torch.as_tensor(label)
label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
- return out_batch, label, name
+ return out_batch, label, name #, name_batch
def __len__(self):
return len(self.data_info)
@@ -184,7 +177,9 @@ class HDF5MILDataloader(data.Dataset):
def _add_data_infos(self, file_path, load_data):
wsi_name = Path(file_path).stem
if wsi_name in self.slideLabelDict:
+ # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
label = self.slideLabelDict[wsi_name]
+ # print(wsi_name)
idx = -1
self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
@@ -195,14 +190,29 @@ class HDF5MILDataloader(data.Dataset):
"""
with h5py.File(file_path, 'r') as h5_file:
wsi_batch = []
+ tile_names = []
for tile in h5_file.keys():
+
img = h5_file[tile][:]
img = img.astype(np.uint8)
img = torch.from_numpy(img)
# img = self.resize_transforms(img)
+
wsi_batch.append(img)
- wsi_batch = torch.stack(wsi_batch)
- wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size)
+ tile_names.append(tile)
+
+ # print('Empty Container: ', file_path) #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
+
+ if wsi_batch:
+ wsi_batch = torch.stack(wsi_batch)
+ else:
+ print('Empty Container: ', file_path)
+ wsi_batch = torch.randn(self.bag_size,3,256,256)
+
+ if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]):
+ print(file_path)
+ print(wsi_batch.shape)
+ wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size)
idx = self._add_to_cache(wsi_batch, file_path)
file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
self.data_info[file_idx + idx]['cache_idx'] = idx
@@ -461,16 +471,19 @@ class RandomHueSaturationValue(object):
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
return img #, lbl
-def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512):
+def to_fixed_size_bag(bag, bag_size: int = 512):
# get up to bag_size elements
bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
bag_samples = bag[bag_idxs]
+ # bag_sample_names = [bag_names[i] for i in bag_idxs]
# zero-pad if we don't have enough samples
zero_padded = torch.cat((bag_samples,
torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+ # zero_padded_names = bag_sample_names + ['']*(bag_size - len(bag_sample_names))
return zero_padded, min(bag_size, len(bag))
+ # return zero_padded, zero_padded_names, min(bag_size, len(bag))
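+# Minimal usage sketch, assuming `bag` is an [N, C, H, W] tensor of tiles:
+#   bag = torch.randn(37, 3, 256, 256)
+#   padded, n_real = to_fixed_size_bag(bag, bag_size=512)
+#   padded.shape == (512, 3, 256, 256); n_real == 37 (the rest is zero padding)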
class RandomHueSaturationValue(object):
@@ -510,11 +523,11 @@ if __name__ == '__main__':
train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
# label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
- label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+ label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
os.makedirs(output_path, exist_ok=True)
- n_classes = 2
+ n_classes = 5
dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
# print(dataset.dataset)
@@ -528,15 +541,19 @@ if __name__ == '__main__':
# print(len(dataset))
# # x = 0
+ #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5
c = 0
label_count = [0] *n_classes
for item in dl:
- # if c >=10:
- # break
+ if c >=10:
+ break
bag, label, name = item
- label_count[np.argmax(label)] += 1
- print(label_count)
- print(len(train_ds))
+ print(name)
+ # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
+
+ # print(bag)
+ # print(label)
+ c += 1
# # # print(bag.shape)
# # if bag.shape[1] == 1:
# # print(name)
@@ -578,7 +595,7 @@ if __name__ == '__main__':
# o_img = Image.fromarray(o_img)
# o_img = o_img.convert('RGB')
# o_img.save(f'{output_path}/{i}_original.png')
- # c += 1
+
# break
# else: break
# print(data.shape)
diff --git a/datasets/custom_jpg_dataloader.py b/datasets/custom_jpg_dataloader.py
new file mode 100644
index 0000000..722b275
--- /dev/null
+++ b/datasets/custom_jpg_dataloader.py
@@ -0,0 +1,459 @@
+'''
+ToDo: remove bag_size
+'''
+
+
+import numpy as np
+from pathlib import Path
+import torch
+from torch.utils import data
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+import torchvision.transforms as transforms
+from PIL import Image
+import cv2
+import json
+import albumentations as A
+from imgaug import augmenters as iaa
+import imgaug as ia
+from torchsampler import ImbalancedDatasetSampler
+
+
+class RangeNormalization(object):
+ def __call__(self, sample):
+ img = sample
+ return (img / 255.0 - 0.5) / 0.5
+
+class JPGMILDataloader(data.Dataset):
+ """Represents an abstract HDF5 dataset. For single H5 container!
+
+ Input params:
+ file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
+ mode: 'train' or 'test'
+ load_data: If True, loads all the data immediately into RAM. Use this if
+ the dataset is fits into memory. Otherwise, leave this at false and
+ the data will load lazily.
+ data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+
+ """
+ def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024):
+ super().__init__()
+
+ self.data_info = []
+ self.data_cache = {}
+ self.slideLabelDict = {}
+ self.files = []
+ self.data_cache_size = data_cache_size
+ self.mode = mode
+ self.file_path = file_path
+ # self.csv_path = csv_path
+ self.label_path = label_path
+ self.n_classes = n_classes
+ self.bag_size = bag_size
+ self.empty_slides = []
+ # self.label_file = label_path
+ recursive = True
+
+ # read labels and slide_path from csv
+ with open(self.label_path, 'r') as f:
+ temp_slide_label_dict = json.load(f)[mode]
+ for (x, y) in temp_slide_label_dict:
+ x = Path(x).stem
+
+ # x_complete_path = Path(self.file_path)/Path(x)
+ for cohort in Path(self.file_path).iterdir():
+ x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
+ if x_complete_path.is_dir():
+ if len(list(x_complete_path.iterdir())) > 50:
+ # print(x_complete_path)
+ self.slideLabelDict[x] = y
+ self.files.append(x_complete_path)
+ else: self.empty_slides.append(x_complete_path)
+ # print(len(self.empty_slides))
+ # print(self.empty_slides)
+
+
+ for slide_dir in tqdm(self.files):
+ self._add_data_infos(str(slide_dir.resolve()), load_data)
+
+
+ self.resize_transforms = A.Compose([
+ A.SmallestMaxSize(max_size=256)
+ ])
+ sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+ sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+ sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+ sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+ sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+ self.train_transforms = iaa.Sequential([
+ iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+ sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+ iaa.Fliplr(0.5, name="MyFlipLR"),
+ iaa.Flipud(0.5, name="MyFlipUD"),
+ sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+ iaa.OneOf([
+ sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+ sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+ sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+ ], name="MyOneOf")
+
+ ], name="MyAug")
+
+ # self.train_transforms = A.Compose([
+ # A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+ # # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+ # # A.RandomGamma(),
+ # # A.HorizontalFlip(),
+ # # A.VerticalFlip(),
+ # # A.RandomRotate90(),
+ # # A.OneOf([
+ # # A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+ # # A.Affine(
+ # # scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+ # # rotate=(-45, 45),
+ # # shear=(-4, 4),
+ # # cval=8,
+ # # )
+ # # ]),
+ # A.Normalize(),
+ # ToTensorV2(),
+ # ])
+ self.val_transforms = transforms.Compose([
+ # A.Normalize(),
+ # ToTensorV2(),
+ RangeNormalization(),
+ transforms.ToTensor(),
+
+ ])
+ self.img_transforms = transforms.Compose([
+ transforms.RandomHorizontalFlip(p=1),
+ transforms.RandomVerticalFlip(p=1),
+ # histoTransforms.AutoRandomRotation(),
+ transforms.Lambda(lambda a: np.array(a)),
+ ])
+ self.hsv_transforms = transforms.Compose([
+ RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
+ transforms.ToTensor()
+ ])
+
+ def __getitem__(self, index):
+ # get data
+ batch, label, name = self.get_data(index)
+ out_batch = []
+ seq_img_d = self.train_transforms.to_deterministic()
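+        # freeze the augmentation's random state so every tile of this bag is augmented identically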
+
+        if self.mode == 'train':
+            for img in batch:  # imgaug expects uint8 numpy arrays
+                img = img.numpy().astype(np.uint8)
+                img = seq_img_d.augment_image(img)
+                img = self.val_transforms(img)
+                out_batch.append(img)
+
+ else:
+ for img in batch:
+ img = img.numpy().astype(np.uint8)
+ img = self.val_transforms(img)
+ out_batch.append(img)
+
+        out_batch = torch.stack(out_batch)
+        # optionally shuffle tiles within the bag:
+        # out_batch = out_batch[torch.randperm(out_batch.shape[0])]
+        label = torch.as_tensor(label)
+        label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
+ return out_batch, label, name #, name_batch
+
+ def __len__(self):
+ return len(self.data_info)
+
+ def _add_data_infos(self, file_path, load_data):
+ wsi_name = Path(file_path).stem
+ if wsi_name in self.slideLabelDict:
+ # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset
+ label = self.slideLabelDict[wsi_name]
+ # print(wsi_name)
+ idx = -1
+ self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
+
+ def _load_data(self, file_path):
+ """Load data to the cache given the file
+ path and update the cache index in the
+ data_info structure.
+ """
+        wsi_batch = []
+        tile_names = []
+        for tile_path in Path(file_path).iterdir():
+            img = np.asarray(Image.open(tile_path)).astype(np.uint8)
+            img = torch.from_numpy(img)
+            wsi_batch.append(img)
+            tile_names.append(tile_path.stem)
+
+        wsi_batch = torch.stack(wsi_batch)
+        if len(wsi_batch.shape) < 4:
+            wsi_batch = wsi_batch.unsqueeze(0)  # unsqueeze is not in-place; assign the result
+        idx = self._add_to_cache(wsi_batch, file_path)
+ file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+ self.data_info[file_idx + idx]['cache_idx'] = idx
+
+ # remove an element from data cache if size was exceeded
+ if len(self.data_cache) > self.data_cache_size:
+            # evict the oldest cached slide that is not the current file
+ removal_keys = list(self.data_cache)
+ removal_keys.remove(file_path)
+ self.data_cache.pop(removal_keys[0])
+ # remove invalid cache_idx
+ self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+ def _add_to_cache(self, data, data_path):
+ """Adds data to the cache and returns its index. There is one cache
+ list for every file_path, containing all datasets in that file.
+ """
+ if data_path not in self.data_cache:
+ self.data_cache[data_path] = [data]
+ else:
+ self.data_cache[data_path].append(data)
+ return len(self.data_cache[data_path]) - 1
+
+    def get_name(self, i):
+        return self.data_info[i]['name']
+
+    def get_labels(self, indices):
+        return [self.data_info[i]['label'] for i in indices]
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+        dataset. This will make sure that the data is loaded in case it is
+        not part of the data cache.
+        i = index
+        """
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+
+        # get the cache_idx assigned by _load_data
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        return self.data_cache[fp][cache_idx], label, name
+
+
+
+class RandomHueSaturationValue(object):
+
+ def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
+
+ self.hue_shift_limit = hue_shift_limit
+ self.sat_shift_limit = sat_shift_limit
+ self.val_shift_limit = val_shift_limit
+ self.p = p
+
+ def __call__(self, sample):
+
+ img = sample #,lbl
+
+ if np.random.random() < self.p:
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
+ h, s, v = cv2.split(img)
+ hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
+ hue_shift = np.uint8(hue_shift)
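+            # uint8 addition wraps modulo 256; note OpenCV stores 8-bit hue in [0, 179]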
+ h += hue_shift
+ sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
+ s = cv2.add(s, sat_shift)
+ val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
+ v = cv2.add(v, val_shift)
+ img = cv2.merge((h, s, v))
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+ return img #, lbl
+
+def to_fixed_size_bag(bag, bag_size: int = 512):
+
+    # sample up to bag_size tiles at random
+    bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+    bag_samples = bag[bag_idxs]
+    # repeat the sampled tiles until the bag holds exactly bag_size instances
+    q, r = divmod(bag_size, bag_samples.shape[0])
+    if q > 0:
+        bag_samples = torch.cat([bag_samples] * q, 0)
+
+    self_padded = torch.cat([bag_samples, bag_samples[:r, :, :, :]])
+
+    # alternative: zero-pad with torch.zeros instead of repeating tiles
+
+ return self_padded, min(bag_size, len(bag))
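+
+# Minimal sketch of the padding behaviour (tile shape [N, C, H, W] assumed):
+#   bag = torch.randn(7, 3, 256, 256)
+#   padded, n_real = to_fixed_size_bag(bag, bag_size=16)
+#   # padded.shape == (16, 3, 256, 256); n_real == 7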
+
+
+if __name__ == '__main__':
+ from pathlib import Path
+ import os
+
+ home = Path.cwd().parts[1]
+ train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
+ data_root = f'/{home}/ylan/data/DeepGraft/256_256um'
+ # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+ # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
+ label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+ output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
+ os.makedirs(output_path, exist_ok=True)
+
+ n_classes = 2
+
+ dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
+    # batch_size=None: each item from the dataset is already a full bag
+    dl = DataLoader(dataset, None, num_workers=1)
+    print(len(dl))
+    dl = DataLoader(dataset, None, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+
+    # count how often each class is drawn, as a sanity check for the sampler
+    label_count = [0] * n_classes
+    print(len(dl))
+    for bag, label, name in dl:
+        label_count[torch.argmax(label)] += 1
+    print(label_count)
+
+    # optional: dump one bag's augmented tiles to output_path for visual inspection
+    # bag = bag.squeeze()
+    # for i in range(bag.shape[0]):
+    #     img = bag[i].squeeze()
+    #     img = ((img - img.min()) / (img.max() - img.min())) * 255
+    #     img = img.numpy().astype(np.uint8).transpose(1, 2, 0)
+    #     Image.fromarray(img).convert('RGB').save(f'{output_path}/{i}.png')
\ No newline at end of file
diff --git a/datasets/data_interface.py b/datasets/data_interface.py
index 056e6ff..efa104c 100644
--- a/datasets/data_interface.py
+++ b/datasets/data_interface.py
@@ -2,15 +2,25 @@ import inspect # inspect the parameters of Python classes, modules and functions
import importlib # In order to dynamically import the library
from typing import Optional
import pytorch_lightning as pl
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+
from torch.utils.data import random_split, DataLoader
+from torch.utils.data.dataset import Dataset, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
from .camel_dataloader import FeatureBagLoader
from .custom_dataloader import HDF5MILDataloader
+from .custom_jpg_dataloader import JPGMILDataloader
from pathlib import Path
from transformers import AutoFeatureExtractor
from torchsampler import ImbalancedDatasetSampler
+from abc import ABC, abstractmethod
+from sklearn.model_selection import KFold
+
+
+
class DataInterface(pl.LightningDataModule):
def __init__(self, train_batch_size=64, train_num_workers=8, test_batch_size=1, test_num_workers=1,dataset_name=None, **kwargs):
@@ -109,7 +119,7 @@ class DataInterface(pl.LightningDataModule):
class MILDataModule(pl.LightningDataModule):
- def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+ def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=50, n_classes=2, cache: bool=True, *args, **kwargs):
super().__init__()
self.data_root = data_root
self.label_path = label_path
@@ -124,27 +134,29 @@ class MILDataModule(pl.LightningDataModule):
self.num_bags_test = 50
self.seed = 1
- self.backbone = backbone
self.cache = True
self.fe_transform = None
+
def setup(self, stage: Optional[str] = None) -> None:
home = Path.cwd().parts[1]
if stage in (None, 'fit'):
- dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
+ dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
a = int(len(dataset)* 0.8)
b = int(len(dataset) - a)
self.train_data, self.valid_data = random_split(dataset, [a, b])
if stage in (None, 'test'):
- self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+ self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1)
return super().setup(stage=stage)
+
+
def train_dataloader(self) -> DataLoader:
- return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True,
+ return DataLoader(self.train_data, batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True,
#sampler=ImbalancedDatasetSampler(self.train_data)
def val_dataloader(self) -> DataLoader:
return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
@@ -187,13 +199,92 @@ class DataModule(pl.LightningDataModule):
if stage in (None, 'test'):
self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+
return super().setup(stage=stage)
def train_dataloader(self) -> DataLoader:
- return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True,
+ return DataLoader(self.train_data, self.batch_size, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True,
#sampler=ImbalancedDatasetSampler(self.train_data),
def val_dataloader(self) -> DataLoader:
- return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
+ return DataLoader(self.valid_data, batch_size = self.batch_size)
+
+ def test_dataloader(self) -> DataLoader:
+ return DataLoader(self.test_data, batch_size = self.batch_size) #, num_workers=self.num_workers
+
+
+class BaseKFoldDataModule(pl.LightningDataModule, ABC):
+ @abstractmethod
+ def setup_folds(self, num_folds: int) -> None:
+ pass
+
+ @abstractmethod
+ def setup_fold_index(self, fold_index: int) -> None:
+ pass
+
+class CrossVal_MILDataModule(BaseKFoldDataModule):
+
+ def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+ super().__init__()
+ self.data_root = data_root
+ self.label_path = label_path
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.image_size = 384
+ self.n_classes = n_classes
+ self.target_number = 9
+ self.mean_bag_length = 10
+ self.var_bag_length = 2
+ self.num_bags_train = 200
+ self.num_bags_test = 50
+ self.seed = 1
+
+ self.backbone = backbone
+ self.cache = True
+ self.fe_transform = None
+
+ # train_dataset: Optional[Dataset] = None
+ # test_dataset: Optional[Dataset] = None
+ # train_fold: Optional[Dataset] = None
+ # val_fold: Optional[Dataset] = None
+ self.train_data : Optional[Dataset] = None
+ self.test_data : Optional[Dataset] = None
+ self.train_fold : Optional[Dataset] = None
+ self.val_fold : Optional[Dataset] = None
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ home = Path.cwd().parts[1]
+
+ # if stage in (None, 'fit'):
+ dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
+ # a = int(len(dataset)* 0.8)
+ # b = int(len(dataset) - a)
+ # self.train_data, self.val_data = random_split(dataset, [a, b])
+ self.train_data = dataset
+
+ # if stage in (None, 'test'):,
+ self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes)
+
+ # return super().setup(stage=stage)
+
+ def setup_folds(self, num_folds: int) -> None:
+ self.num_folds = num_folds
+ self.splits = [split for split in KFold(num_folds).split(range(len(self.train_data)))]
+
+ def setup_fold_index(self, fold_index: int) -> None:
+ train_indices, val_indices = self.splits[fold_index]
+ self.train_fold = Subset(self.train_data, train_indices)
+ self.val_fold = Subset(self.train_data, val_indices)
+
+
+ def train_dataloader(self) -> DataLoader:
+ return DataLoader(self.train_fold, self.batch_size, sampler=ImbalancedDatasetSampler(self.train_fold), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True,
+ # return DataLoader(self.train_fold, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True,
+ #sampler=ImbalancedDatasetSampler(self.train_data)
+ def val_dataloader(self) -> DataLoader:
+ return DataLoader(self.val_fold, batch_size = self.batch_size, num_workers=self.num_workers)
def test_dataloader(self) -> DataLoader:
- return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
\ No newline at end of file
+ return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
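+
+# Minimal usage sketch (a fold-aware fit loop such as KFoldLoop from train_loop.py
+# is assumed to drive this; paths below are placeholders):
+#   dm = CrossVal_MILDataModule(data_root='/path/to/BLOCKS', label_path='split.json', n_classes=2)
+#   dm.setup()
+#   dm.setup_folds(3)
+#   dm.setup_fold_index(0)
+#   train_dl, val_dl = dm.train_dataloader(), dm.val_dataloader()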
+
+
+
diff --git a/models/AttMIL.py b/models/AttMIL.py
new file mode 100644
index 0000000..d1e20eb
--- /dev/null
+++ b/models/AttMIL.py
@@ -0,0 +1,79 @@
+import os
+import logging
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+import pytorch_lightning as pl
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+
+
+class AttMIL(nn.Module): #gated attention
+ def __init__(self, n_classes, features=512):
+ super(AttMIL, self).__init__()
+ self.L = features
+ self.D = 128
+ self.K = 1
+ self.n_classes = n_classes
+
+ self.attention_V = nn.Sequential(
+ nn.Linear(self.L, self.D),
+ nn.Tanh()
+ )
+
+ self.attention_U = nn.Sequential(
+ nn.Linear(self.L, self.D),
+ nn.Sigmoid()
+ )
+
+ self.attention_weights = nn.Linear(self.D, self.K)
+
+ self.classifier = nn.Sequential(
+ nn.Linear(self.L * self.K, self.n_classes),
+ )
+
+ def forward(self, x):
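+        # gated attention pooling (Ilse et al., 2018):
+        #   a = softmax( w^T (tanh(V H) * sigmoid(U H)) ),  M = a @ H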
+        H = x.float().squeeze(0)  # H: [N, features] tile embeddings of one bag
+ A_V = self.attention_V(H) # NxD
+ A_U = self.attention_U(H) # NxD
+        A = self.attention_weights(A_V * A_U)  # element-wise gating  # NxK
+        A = torch.transpose(A, 1, 0)  # KxN
+ A = F.softmax(A, dim=1) # softmax over N
+ M = torch.mm(A, H) # KxL
+ logits = self.classifier(M)
+
+ return logits
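+
+# Minimal usage sketch: a bag of N pre-extracted tile features -> class logits.
+#   model = AttMIL(n_classes=2, features=512)
+#   bag = torch.randn(1, 100, 512)   # [B=1, N=100 tiles, features]
+#   logits = model(bag)              # [1, n_classes]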
\ No newline at end of file
diff --git a/models/TransMIL.py b/models/TransMIL.py
index 69089de..ca2a1fb 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -44,23 +44,23 @@ class PPEG(nn.Module):
class TransMIL(nn.Module):
- def __init__(self, n_classes):
+ def __init__(self, n_classes, in_features, out_features=384):
super(TransMIL, self).__init__()
- self.pos_layer = PPEG(dim=512)
- self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
+ self.pos_layer = PPEG(dim=out_features)
+ self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
# self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
- self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
+ self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
self.n_classes = n_classes
- self.layer1 = TransLayer(dim=512)
- self.layer2 = TransLayer(dim=512)
- self.norm = nn.LayerNorm(512)
- self._fc2 = nn.Linear(512, self.n_classes)
+ self.layer1 = TransLayer(dim=out_features)
+ self.layer2 = TransLayer(dim=out_features)
+ self.norm = nn.LayerNorm(out_features)
+ self._fc2 = nn.Linear(out_features, self.n_classes)
- def forward(self, **kwargs): #, **kwargs
+ def forward(self, x): #, **kwargs
- h = kwargs['data'].float() #[B, n, 1024]
- # h = self._fc1(h) #[B, n, 512]
+ h = x.float() #[B, n, 1024]
+ h = self._fc1(h) #[B, n, 512]
#---->pad
H = h.shape[1]
@@ -83,22 +83,22 @@ class TransMIL(nn.Module):
h = self.layer2(h) #[B, N, 512]
#---->cls_token
+        # h: [B, N+1, out_features] tokens (cls_token + padded tile tokens)
h = self.norm(h)[:,0]
#---->predict
logits = self._fc2(h) #[B, n_classes]
- Y_hat = torch.argmax(logits, dim=1)
- Y_prob = F.softmax(logits, dim = 1)
- results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
- return results_dict
+ return logits
if __name__ == "__main__":
data = torch.randn((1, 6000, 512)).cuda()
- model = TransMIL(n_classes=2).cuda()
+ model = TransMIL(n_classes=2, in_features=512).cuda()
print(model.eval())
- results_dict = model(data = data)
+ results_dict = model(data)
print(results_dict)
- logits = results_dict['logits']
- Y_prob = results_dict['Y_prob']
- Y_hat = results_dict['Y_hat']
+ # logits = results_dict['logits']
+ # Y_prob = results_dict['Y_prob']
+ # Y_hat = results_dict['Y_hat']
# print(F.sigmoid(logits))
diff --git a/models/__pycache__/AttMIL.cpython-39.pyc b/models/__pycache__/AttMIL.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ee3af4e0d05559e5bc7a4d8ed62481a3254c9f1
diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
index cb0eb6fd5085feda9604ec0f24420d85eb1d939d..a329e23346d376b150c7ba792b8cd3593a128d95 100644
diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
index 0bf337675c76de4097e7f7723b99de0120f9594c..e9d22d7ddccbb2a3d2ab59b98e066f9ac5d42285 100644
diff --git a/models/model_interface.py b/models/model_interface.py
index 60b5cc7..b3a561e 100755
--- a/models/model_interface.py
+++ b/models/model_interface.py
@@ -7,6 +7,8 @@ import pandas as pd
import seaborn as sns
from pathlib import Path
from matplotlib import pyplot as plt
+import cv2
+from PIL import Image
#---->
from MyOptimizer import create_optimizer
@@ -20,7 +22,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
+from torchmetrics.functional import stat_scores
+import torchvision
from torch import optim as optim
+# from sklearn.metrics import roc_curve, roc_auc_score
+
#---->
import pytorch_lightning as pl
@@ -29,6 +34,10 @@ from torchvision import models
from torchvision.models import resnet
from transformers import AutoFeatureExtractor, ViTModel
+from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
from captum.attr import LayerGradCam
class ModelInterface(pl.LightningModule):
@@ -41,11 +50,20 @@ class ModelInterface(pl.LightningModule):
self.loss = create_loss(loss)
# self.asl = AsymmetricLossSingleLabel()
# self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
-
# self.loss =
+ # print(self.model)
+
+
+ # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
self.optimizer = optimizer
self.n_classes = model.n_classes
- self.log_path = kargs['log']
+ print(self.n_classes)
+ self.save_path = kargs['log']
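+        # rewrite the legacy 'tcmr' path component to 'tcmr_viral' in the save path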
+ if Path(self.save_path).parts[3] == 'tcmr':
+ temp = list(Path(self.save_path).parts)
+ # print(temp)
+ temp[3] = 'tcmr_viral'
+ self.save_path = '/'.join(temp)
#---->acc
self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
@@ -53,6 +71,7 @@ class ModelInterface(pl.LightningModule):
#---->Metrics
if self.n_classes > 2:
self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+
metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
average='micro'),
torchmetrics.CohenKappa(num_classes = self.n_classes),
@@ -67,6 +86,7 @@ class ModelInterface(pl.LightningModule):
else :
self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+
metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
average = 'micro'),
torchmetrics.CohenKappa(num_classes = 2),
@@ -76,6 +96,8 @@ class ModelInterface(pl.LightningModule):
num_classes = 2),
torchmetrics.Precision(average = 'macro',
num_classes = 2)])
+ self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes)
+ # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10)
self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
self.valid_metrics = metrics.clone(prefix = 'val_')
self.test_metrics = metrics.clone(prefix = 'test_')
@@ -146,28 +168,40 @@ class ModelInterface(pl.LightningModule):
nn.Linear(1024, self.out_features),
nn.ReLU(),
)
+ # print(self.model_ft[0].features[-1])
+ # print(self.model_ft)
+ if model.name == 'TransMIL':
+ target_layers = [self.model.layer2.norm] # 32x32
+ # target_layers = [self.model_ft[0].features[-1]] # 32x32
+ self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
+ # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+ else:
+ target_layers = [self.model.attention_weights]
+ self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
+
+ def forward(self, x):
+
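+        # x: [N, C, H, W] bag of tiles -> per-tile features [1, N, F] -> bag-level logits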
+ feats = self.model_ft(x).unsqueeze(0)
+ return self.model(feats)
+
+ def step(self, input):
+
+ input = input.squeeze(0).float()
+ logits = self(input)
+
+ Y_hat = torch.argmax(logits, dim=1)
+ Y_prob = F.softmax(logits, dim=1)
+
+ return logits, Y_prob, Y_hat
def training_step(self, batch, batch_idx):
#---->inference
- data, label, _ = batch
+
+ input, label, _= batch
label = label.float()
- data = data.squeeze(0).float()
- # print(data)
- # print(data.shape)
- if self.backbone == 'dino':
- features = self.model_ft(**data)
- features = features.last_hidden_state
- else:
- features = self.model_ft(data)
- features = features.unsqueeze(0)
- # print(features.shape)
- # features = features.squeeze()
- results_dict = self.model(data=features)
- # results_dict = self.model(data=data, label=label)
- logits = results_dict['logits']
- Y_prob = results_dict['Y_prob']
- Y_hat = results_dict['Y_hat']
+
+ logits, Y_prob, Y_hat = self.step(input)
#---->loss
loss = self.loss(logits, label)
@@ -183,6 +217,14 @@ class ModelInterface(pl.LightningModule):
# Y = int(label[0])
self.data[Y]["count"] += 1
self.data[Y]["correct"] += (int(Y_hat) == Y)
+ self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True)
+
+        if self.current_epoch % 10 == 0:
+            # build a grid of this bag's input tiles (logging call left disabled)
+            grid = torchvision.utils.make_grid(input.squeeze(0))
+            # self.loggers[0].experiment.add_image('train/input', grid, self.current_epoch)
+
return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label}
@@ -212,18 +254,10 @@ class ModelInterface(pl.LightningModule):
def validation_step(self, batch, batch_idx):
- data, label, _ = batch
-
+ input, label, _ = batch
label = label.float()
- data = data.squeeze(0).float()
- features = self.model_ft(data)
- features = features.unsqueeze(0)
-
- results_dict = self.model(data=features)
- logits = results_dict['logits']
- Y_prob = results_dict['Y_prob']
- Y_hat = results_dict['Y_hat']
-
+
+ logits, Y_prob, Y_hat = self.step(input)
#---->acc log
# Y = int(label[0][1])
@@ -237,18 +271,23 @@ class ModelInterface(pl.LightningModule):
def validation_epoch_end(self, val_step_outputs):
logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
- # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
- # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
target = torch.cat([x['label'] for x in val_step_outputs])
target = torch.argmax(target, dim=1)
#---->
# logits = logits.long()
# target = target.squeeze().long()
# logits = logits.squeeze(0)
+ if len(target.unique()) != 1:
+ self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+ else:
+ self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True)
+
+
+
self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
- self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+
# print(max_probs.squeeze(0).shape)
# print(target.shape)
@@ -276,24 +315,61 @@ class ModelInterface(pl.LightningModule):
random.seed(self.count*50)
def test_step(self, batch, batch_idx):
-
- data, label, _ = batch
+ torch.set_grad_enabled(True)
+ data, label, name = batch
label = label.float()
+ # logits, Y_prob, Y_hat = self.step(data)
+ # print(data.shape)
data = data.squeeze(0).float()
- features = self.model_ft(data)
- features = features.unsqueeze(0)
+ logits = self(data).detach()
+
+ Y = torch.argmax(label)
+ Y_hat = torch.argmax(logits, dim=1)
+ Y_prob = F.softmax(logits, dim = 1)
+
+ #----> Get Topk tiles
+
+ target = [ClassifierOutputTarget(Y)]
+
+ data_ft = self.model_ft(data).unsqueeze(0).float()
+        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
+        # grayscale_ecam = self.ecam(input_tensor=data_ft, targets=target)
+
- results_dict = self.model(data=features, label=label)
- logits = results_dict['logits']
- Y_prob = results_dict['Y_prob']
- Y_hat = results_dict['Y_hat']
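+        # collapse the Grad-CAM map to one relevance score per entry of the token
+        # grid, then keep the five highest-scoring tiles of the bag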
+        summed = torch.mean(torch.Tensor(grayscale_cam), dim=2)
+ topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
+ topk_data = data[topk_indices].detach()
+
+ # target_ft =
+ # grayscale_cam_ft = self.cam_ft(input_tensor=data, )
+ # for i in range(data.shape[0]):
+
+ # vis_img = data[i, :, :, :].cpu().numpy()
+ # vis_img = np.transpose(vis_img, (1,2,0))
+ # print(vis_img.shape)
+ # cam_img = grayscale_cam.squeeze(0)
+ # cam_img = self.reshape_transform(grayscale_cam)
+
+ # print(cam_img.shape)
+
+ # visualization = show_cam_on_image(vis_img, cam_img, use_rgb=True)
+ # visualization = ((visualization/visualization.max())*255.0).astype(np.uint8)
+ # print(visualization)
+ # cv2.imwrite(f'{test_path}/{Y}/{name}/gradcam.jpg', cam_img)
#---->acc log
Y = torch.argmax(label)
self.data[Y]["count"] += 1
self.data[Y]["correct"] += (Y_hat.item() == Y)
- return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+ return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name, 'topk_data': topk_data} #
+ # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data
def test_epoch_end(self, output_results):
probs = torch.cat([x['Y_prob'] for x in output_results])
@@ -301,7 +377,8 @@ class ModelInterface(pl.LightningModule):
# target = torch.stack([x['label'] for x in output_results], dim = 0)
target = torch.cat([x['label'] for x in output_results])
target = torch.argmax(target, dim=1)
-
+ patients = [x['name'] for x in output_results]
+ topk_tiles = [x['topk_data'] for x in output_results]
#---->
auc = self.AUROC(probs, target.squeeze())
metrics = self.test_metrics(max_probs.squeeze() , target)
@@ -312,9 +389,41 @@ class ModelInterface(pl.LightningModule):
# self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
- # print(max_probs.squeeze(0).shape)
- # print(target.shape)
- # self.log_dict(metrics, logger = True)
+ #---->get highest scoring patients for each class
+ test_path = Path(self.save_path) / 'most_predictive'
+ topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
+ for n in range(self.n_classes):
+ print('class: ', n)
+ topk_patients = [patients[i[n]] for i in topk_indices]
+ topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
+ for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
+ print(p, x[n])
+ patient = p[0]
+ outpath = test_path / str(n) / patient
+ outpath.mkdir(parents=True, exist_ok=True)
+ for i in range(len(t)):
+ tile = t[i]
+ tile = tile.cpu().numpy().transpose(1,2,0)
+ tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
+ tile = tile.astype(np.uint8)
+ img = Image.fromarray(tile)
+
+ img.save(f'{test_path}/{n}/{patient}/{i}_gradcam.jpg')
+
+            #----> optional: overlay the CAM on the top tiles
+            # visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
+
for keys, values in metrics.items():
print(f'{keys} = {values}')
metrics[keys] = values.cpu().numpy()
@@ -329,16 +438,35 @@ class ModelInterface(pl.LightningModule):
print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        #---->plot roc curve (sketch; would need a per-class roc_curve from sklearn)
+        # fpr, tpr, thresh = roc_curve(target.cpu().numpy(), probs.cpu().numpy())
+
self.log_confusion_matrix(probs, target, stage='test')
#---->
result = pd.DataFrame([metrics])
- result.to_csv(self.log_path / 'result.csv')
+        result_path = Path(self.save_path) / 'test_result.csv'
+        result.to_csv(result_path, mode='a', header=not result_path.exists())
+
+ # with open(f'{self.save_path}/test_metrics.txt', 'a') as f:
+
+ # f.write([metrics])
def configure_optimizers(self):
# optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
optimizer = create_optimizer(self.optimizer, self.model)
return optimizer
+ def reshape_transform(self, tensor, h=32, w=32):
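+        # drop the cls token and fold the [B, N, C] token sequence into a [B, C, h, w]
+        # feature map so Grad-CAM can treat the transformer layer like a conv layer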
+ result = tensor[:, 1:, :].reshape(tensor.size(0), h, w, tensor.size(2))
+ result = result.transpose(2,3).transpose(1,2)
+ # print(result.shape)
+ return result
def load_model(self):
name = self.hparams.model.name
@@ -372,18 +500,33 @@ class ModelInterface(pl.LightningModule):
args1.update(other_args)
return Model(**args1)
+    def log_image(self, tensor, stage, name):
+
+        tile = tensor.cpu().numpy().transpose(1, 2, 0)
+        tile = (tile - tile.min()) / (tile.max() - tile.min()) * 255
+        tile = tile.astype(np.uint8)
+        img = Image.fromarray(tile)
+        self.loggers[0].experiment.add_figure(f'{stage}/{name}', img, self.current_epoch)
+
+
def log_confusion_matrix(self, max_probs, target, stage):
confmat = self.confusion_matrix(max_probs.squeeze(), target)
+ print(confmat)
df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
- plt.figure()
+ # plt.figure()
fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
# plt.close(fig_)
- # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
- self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+ # plt.savefig(f'{self.save_path}/cm_e{self.current_epoch}')
+
- if stage == 'test':
- plt.savefig(f'{self.log_path}/cm_test')
- plt.close(fig_)
+ if stage == 'train':
+ # print(self.save_path)
+ # plt.savefig(f'{self.save_path}/cm_test')
+
+ self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+ else:
+ fig_.savefig(f'{self.save_path}/cm_test.png', dpi=400)
+ # plt.close(fig_)
# self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
class View(nn.Module):
diff --git a/test_visualize.py b/test_visualize.py
new file mode 100644
index 0000000..7ef56d3
--- /dev/null
+++ b/test_visualize.py
@@ -0,0 +1,148 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import glob
+
+from sklearn.model_selection import KFold
+
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
+from models.model_interface import ModelInterface
+import models.vision_transformer as vits
+from utils.utils import *
+
+# pytorch_lightning
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+import torch
+from train_loop import KFoldLoop
+
+#--->Setting parameters
+def make_parse():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--stage', default='train', type=str)
+ parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+ parser.add_argument('--version', default=0,type=int)
+ parser.add_argument('--epoch', default='0',type=str)
+ parser.add_argument('--gpus', default = 2, type=int)
+ parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
+ parser.add_argument('--fold', default = 0)
+ parser.add_argument('--bag_size', default = 1024, type=int)
+
+ args = parser.parse_args()
+ return args
+
+#---->main
+def main(cfg):
+
+ torch.set_num_threads(16)
+
+ #---->Initialize seed
+ pl.seed_everything(cfg.General.seed)
+
+ #---->load loggers
+ # cfg.load_loggers = load_loggers(cfg)
+
+ # print(cfg.load_loggers)
+ # save_path = Path(cfg.load_loggers[0].log_dir)
+
+ #---->load callbacks
+ # cfg.callbacks = load_callbacks(cfg, save_path)
+
+ home = Path.cwd().parts[1]
+ DataInterface_dict = {
+ 'data_root': cfg.Data.data_dir,
+ 'label_path': cfg.Data.label_file,
+ 'batch_size': cfg.Data.train_dataloader.batch_size,
+ 'num_workers': cfg.Data.train_dataloader.num_workers,
+ 'n_classes': cfg.Model.n_classes,
+ 'backbone': cfg.Model.backbone,
+ 'bag_size': cfg.Data.bag_size,
+ }
+
+ dm = MILDataModule(**DataInterface_dict)
+
+
+ #---->Define Model
+ ModelInterface_dict = {'model': cfg.Model,
+ 'loss': cfg.Loss,
+ 'optimizer': cfg.Optimizer,
+ 'data': cfg.Data,
+ 'log': cfg.log_path,
+ 'backbone': cfg.Model.backbone,
+ }
+ model = ModelInterface(**ModelInterface_dict)
+
+ #---->Instantiate Trainer
+ trainer = Trainer(
+ num_sanity_val_steps=0,
+ # logger=cfg.load_loggers,
+ # callbacks=cfg.callbacks,
+ max_epochs= cfg.General.epochs,
+ min_epochs = 200,
+ gpus=cfg.General.gpus,
+ # gpus = [0,2],
+ # strategy='ddp',
+ amp_backend='native',
+ # amp_level=cfg.General.amp_level,
+ precision=cfg.General.precision,
+ accumulate_grad_batches=cfg.General.grad_acc,
+ # fast_dev_run = True,
+
+ # deterministic=True,
+ check_val_every_n_epoch=10,
+ )
+
+ #---->train or test
+ log_path = cfg.log_path
+ # print(log_path)
+ # log_path = Path('lightning_logs/2/checkpoints')
+ model_paths = list(log_path.glob('*.ckpt'))
+
+ if cfg.epoch == 'last':
+        model_paths = [str(model_path) for model_path in model_paths if 'last' in str(model_path)]
+ else:
+ model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+
+ # model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+ # model_paths = [f'lightning_logs/0/.ckpt']
+ # model_paths = [f'{log_path}/last.ckpt']
+ if not model_paths:
+        print('No checkpoints available!')
+ for path in model_paths:
+ # with open(f'{log_path}/test_metrics.txt', 'w') as f:
+ # f.write(str(path) + '\n')
+ print(path)
+ new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
+ trainer.test(model=new_model, datamodule=dm)
+
+ # Top 5 scoring patches for patient
+ # GradCam
+
+
+if __name__ == '__main__':
+
+ args = make_parse()
+ cfg = read_yaml(args.config)
+
+ #---->update
+ cfg.config = args.config
+ cfg.General.gpus = [args.gpus]
+ cfg.General.server = args.stage
+ cfg.Data.fold = args.fold
+ cfg.Loss.base_loss = args.loss
+ cfg.Data.bag_size = args.bag_size
+ cfg.version = args.version
+ cfg.epoch = args.epoch
+
+ log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+ Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+ log_name = f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+ task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+ # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+ cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' / 'checkpoints'
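+    # resulting checkpoint dir, e.g.:
+    #   logs/DeepGraft/<Model.name>/<task>/_<backbone>_<loss>/lightning_logs/version_<n>/checkpoints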
+
+
+
+ #---->main
+ main(cfg)
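+
+    # Example invocation (values are illustrative, defaults come from make_parse):
+    #   python test_visualize.py --stage test --config DeepGraft/TransMIL.yaml \
+    #       --version 0 --epoch last --gpus 0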
+
\ No newline at end of file
diff --git a/train.py b/train.py
index 036d5ed..5e30394 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ import glob
from sklearn.model_selection import KFold
-from datasets.data_interface import DataInterface, MILDataModule
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
from models.model_interface import ModelInterface
import models.vision_transformer as vits
from utils.utils import *
@@ -13,18 +13,23 @@ from utils.utils import *
# pytorch_lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
-from pytorch_lightning.loops import KFoldLoop
import torch
+from train_loop import KFoldLoop
#--->Setting parameters
def make_parse():
parser = argparse.ArgumentParser()
parser.add_argument('--stage', default='train', type=str)
parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+ parser.add_argument('--version', default=2,type=int)
parser.add_argument('--gpus', default = 2, type=int)
parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
parser.add_argument('--fold', default = 0)
parser.add_argument('--bag_size', default = 1024, type=int)
+ parser.add_argument('--resume_training', action='store_true')
+ # parser.add_argument('--ckpt_path', default = , type=str)
+
+
args = parser.parse_args()
return args
@@ -39,9 +44,10 @@ def main(cfg):
#---->load loggers
cfg.load_loggers = load_loggers(cfg)
# print(cfg.load_loggers)
+ save_path = Path(cfg.load_loggers[0].log_dir)
#---->load callbacks
- cfg.callbacks = load_callbacks(cfg)
+ cfg.callbacks = load_callbacks(cfg, save_path)
#---->Define Data
# DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
@@ -58,11 +64,12 @@ def main(cfg):
'batch_size': cfg.Data.train_dataloader.batch_size,
'num_workers': cfg.Data.train_dataloader.num_workers,
'n_classes': cfg.Model.n_classes,
- 'backbone': cfg.Model.backbone,
'bag_size': cfg.Data.bag_size,
}
- dm = MILDataModule(**DataInterface_dict)
+ if cfg.Data.cross_val:
+ dm = CrossVal_MILDataModule(**DataInterface_dict)
+    else:
+        dm = MILDataModule(**DataInterface_dict)
#---->Define Model
@@ -82,9 +89,9 @@ def main(cfg):
callbacks=cfg.callbacks,
max_epochs= cfg.General.epochs,
min_epochs = 200,
- # gpus=cfg.General.gpus,
- gpus = [2,3],
- strategy='ddp',
+ gpus=cfg.General.gpus,
+ # gpus = [0,2],
+ # strategy='ddp',
amp_backend='native',
# amp_level=cfg.General.amp_level,
precision=cfg.General.precision,
@@ -96,12 +103,31 @@ def main(cfg):
)
#---->train or test
+    if cfg.resume_training:
+        last_ckpt = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' / 'last.ckpt'
+        trainer.fit(model = model, datamodule = dm, ckpt_path=last_ckpt)
+        # return early so the resumed run is not fitted a second time below
+        return
+
if cfg.General.server == 'train':
- trainer.fit_loop = KFoldLoop(3, trainer.fit_loop, )
- trainer.fit(model = model, datamodule = dm)
+
+ # k-fold cross validation loop
+ if cfg.Data.cross_val:
+ internal_fit_loop = trainer.fit_loop
+ trainer.fit_loop = KFoldLoop(cfg.Data.nfold, export_path = cfg.log_path, **ModelInterface_dict)
+ trainer.fit_loop.connect(internal_fit_loop)
+ trainer.fit(model, dm)
+ else:
+ trainer.fit(model = model, datamodule = dm)
else:
- model_paths = list(cfg.log_path.glob('*.ckpt'))
+ log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}'
+
+ test_path = Path(log_path) / 'test'
+ for n in range(cfg.Model.n_classes):
+ n_output_path = test_path / str(n)
+ n_output_path.mkdir(parents=True, exist_ok=True)
+ # print(cfg.log_path)
+ model_paths = list(log_path.glob('*.ckpt'))
model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
+ # model_paths = [f'{log_path}/epoch=279-val_loss=0.4009.ckpt']
for path in model_paths:
print(path)
new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
@@ -120,6 +146,16 @@ if __name__ == '__main__':
cfg.Data.fold = args.fold
cfg.Loss.base_loss = args.loss
cfg.Data.bag_size = args.bag_size
+ cfg.version = args.version
+
+ log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+ Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+ log_name = f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+ task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+ # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+ cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name
+
+
#---->main
diff --git a/train_loop.py b/train_loop.py
new file mode 100644
index 0000000..f923681
--- /dev/null
+++ b/train_loop.py
@@ -0,0 +1,212 @@
+from pytorch_lightning import LightningModule
+import torch
+import torch.nn.functional as F
+from torchmetrics.classification.accuracy import Accuracy
+import os.path as osp
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from pytorch_lightning import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
+from datasets.data_interface import BaseKFoldDataModule
+from typing import Any, Dict, List, Optional, Type
+import torchmetrics
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+
+class EnsembleVotingModel(LightningModule):
+ def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str], n_classes, log_path) -> None:
+ super().__init__()
+ # Create `num_folds` models with their associated fold weights
+ self.n_classes = n_classes
+ self.log_path = log_path
+ self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+ self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
+ self.test_acc = Accuracy()
+ if self.n_classes > 2:
+ self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+ metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
+ average='micro'),
+ torchmetrics.CohenKappa(num_classes = self.n_classes),
+ torchmetrics.F1Score(num_classes = self.n_classes,
+ average = 'macro'),
+ torchmetrics.Recall(average = 'macro',
+ num_classes = self.n_classes),
+ torchmetrics.Precision(average = 'macro',
+ num_classes = self.n_classes),
+ torchmetrics.Specificity(average = 'macro',
+ num_classes = self.n_classes)])
+
+        else:
+ self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+ metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
+ average = 'micro'),
+ torchmetrics.CohenKappa(num_classes = 2),
+ torchmetrics.F1Score(num_classes = 2,
+ average = 'macro'),
+ torchmetrics.Recall(average = 'macro',
+ num_classes = 2),
+ torchmetrics.Precision(average = 'macro',
+ num_classes = 2)])
+ self.test_metrics = metrics.clone(prefix = 'test_')
+ self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
+
+ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+ # Compute the averaged predictions over the `num_folds` models.
+ # print(batch[0].shape)
+ input, label, _ = batch
+ label = label.float()
+ input = input.squeeze(0).float()
+
+
+ logits = torch.stack([m(input) for m in self.models]).mean(0)
+ Y_hat = torch.argmax(logits, dim=1)
+ Y_prob = F.softmax(logits, dim = 1)
+ # #---->acc log
+ Y = torch.argmax(label)
+ self.data[Y]["count"] += 1
+ self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+ return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+
+ def test_epoch_end(self, output_results):
+ probs = torch.cat([x['Y_prob'] for x in output_results])
+ max_probs = torch.stack([x['Y_hat'] for x in output_results])
+ # target = torch.stack([x['label'] for x in output_results], dim = 0)
+ target = torch.cat([x['label'] for x in output_results])
+ target = torch.argmax(target, dim=1)
+
+ #---->
+ auc = self.AUROC(probs, target.squeeze())
+        metrics = self.test_metrics(max_probs.squeeze(), target)
+
+
+ # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
+ metrics['test_auc'] = auc
+
+ # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+
+ # print(max_probs.squeeze(0).shape)
+ # print(target.shape)
+ # self.log_dict(metrics, logger = True)
+ for keys, values in metrics.items():
+ print(f'{keys} = {values}')
+ metrics[keys] = values.cpu().numpy()
+ #---->acc log
+ for c in range(self.n_classes):
+ count = self.data[c]["count"]
+ correct = self.data[c]["correct"]
+ if count == 0:
+ acc = None
+ else:
+ acc = float(correct) / count
+ print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+ self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+ self.log_confusion_matrix(probs, target, stage='test')
+ #---->
+ result = pd.DataFrame([metrics])
+ result.to_csv(self.log_path / 'result.csv')
+
+
+ def log_confusion_matrix(self, max_probs, target, stage):
+ confmat = self.confusion_matrix(max_probs.squeeze(), target)
+ df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+ plt.figure()
+ fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+ # plt.close(fig_)
+ # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
+ self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+ if stage == 'test':
+ plt.savefig(f'{self.log_path}/cm_test')
+ plt.close(fig_)
+
+class KFoldLoop(Loop):
+ def __init__(self, num_folds: int, export_path: str, **kargs) -> None:
+ super().__init__()
+ self.num_folds = num_folds
+ self.current_fold: int = 0
+ self.export_path = export_path
+ self.n_classes = kargs["model"].n_classes
+ self.log_path = kargs["log"]
+
+ @property
+ def done(self) -> bool:
+ return self.current_fold >= self.num_folds
+
+ def connect(self, fit_loop: FitLoop) -> None:
+ self.fit_loop = fit_loop
+
+ def reset(self) -> None:
+ """Nothing to reset in this loop."""
+
+ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+ """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+ model."""
+ assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+ self.trainer.datamodule.setup_folds(self.num_folds)
+ self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+ def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+ """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+ print(f"STARTING FOLD {self.current_fold}")
+ assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+ self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+ def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Runs fitting and testing on the current fold."""
+ self._reset_fitting() # requires to reset the tracking stage.
+ self.fit_loop.run()
+
+ self._reset_testing() # requires to reset the tracking stage.
+ self.trainer.test_loop.run()
+ self.current_fold += 1 # increment fold tracking number.
+
+ def on_advance_end(self) -> None:
+ """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+ self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+ # restore the original weights + optimizers and schedulers.
+ self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+ self.trainer.strategy.setup_optimizers(self.trainer)
+ self.replace(fit_loop=FitLoop)
+
+ def on_run_end(self) -> None:
+ """Used to compute the performance of the ensemble model on the test set."""
+ checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+ voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths, n_classes=self.n_classes, log_path=self.log_path)
+ voting_model.trainer = self.trainer
+        # This requires connecting the new model and moving it to the right device.
+ self.trainer.strategy.connect(voting_model)
+ self.trainer.strategy.model_to_device()
+ self.trainer.test_loop.run()
+
+ def on_save_checkpoint(self) -> Dict[str, int]:
+ return {"current_fold": self.current_fold}
+
+ def on_load_checkpoint(self, state_dict: Dict) -> None:
+ self.current_fold = state_dict["current_fold"]
+
+ def _reset_fitting(self) -> None:
+ self.trainer.reset_train_dataloader()
+ self.trainer.reset_val_dataloader()
+ self.trainer.state.fn = TrainerFn.FITTING
+ self.trainer.training = True
+
+ def _reset_testing(self) -> None:
+ self.trainer.reset_test_dataloader()
+ self.trainer.state.fn = TrainerFn.TESTING
+ self.trainer.testing = True
+
+ def __getattr__(self, key) -> Any:
+ # requires to be overridden as attributes of the wrapped loop are being accessed.
+ if key not in self.__dict__:
+ return getattr(self.fit_loop, key)
+ return self.__dict__[key]
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ self.__dict__.update(state)
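+
+# Usage sketch: `KFoldLoop` replaces the Trainer's default fit loop
+# (this mirrors the wiring in train.py from this patch):
+#
+#   internal_fit_loop = trainer.fit_loop
+#   trainer.fit_loop = KFoldLoop(cfg.Data.nfold, export_path=cfg.log_path,
+#                                **ModelInterface_dict)
+#   trainer.fit_loop.connect(internal_fit_loop)
+#   trainer.fit(model, dm)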
\ No newline at end of file
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
index 704aade8ba3e8bcbedf237b3fde5db80b84fcd32..33d3d005a578948dd3d5b58938a815f86c4e92f5 100644
GIT binary patch (binary delta for compiled __pycache__ file omitted)
diff --git a/utils/utils.py b/utils/utils.py
index 010eaab..814cf1d 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -10,10 +10,11 @@ from torch.utils.data.dataset import Dataset, Subset
from torchmetrics.classification.accuracy import Accuracy
from pytorch_lightning import LightningDataModule, seed_everything, Trainer
-from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning import LightningModule
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.states import TrainerFn
+from typing import Any, Dict, List, Optional, Type
#---->read yaml
import yaml
@@ -27,30 +28,30 @@ def read_yaml(fpath=None):
from pytorch_lightning import loggers as pl_loggers
def load_loggers(cfg):
- log_path = cfg.General.log_path
- Path(log_path).mkdir(exist_ok=True, parents=True)
- log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
- version_name = Path(cfg.config).name[:-5]
+ # log_path = cfg.General.log_path
+ # Path(log_path).mkdir(exist_ok=True, parents=True)
+ # log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+ # version_name = Path(cfg.config).name[:-5]
#---->TensorBoard
if cfg.stage != 'test':
- cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
- tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
- name = version_name, version = f'fold{cfg.Data.fold}',
+
+ tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
+ # version = f'fold{cfg.Data.fold}'
log_graph = True, default_hp_metric = False)
#---->CSV
- csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
- name = version_name, version = f'fold{cfg.Data.fold}', )
+ csv_logger = pl_loggers.CSVLogger(cfg.log_path,
+ ) # version = f'fold{cfg.Data.fold}',
else:
- cfg.log_path = Path(log_path) / log_name / version_name / f'test'
- tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
- name = version_name, version = f'test',
+ cfg.log_path = Path(cfg.log_path) / f'test'
+ tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
+ version = f'test',
log_graph = True, default_hp_metric = False)
#---->CSV
- csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
- name = version_name, version = f'test', )
-
+ csv_logger = pl_loggers.CSVLogger(cfg.log_path,
+ version = f'test', )
+
print(f'---->Log dir: {cfg.log_path}')
@@ -63,11 +64,11 @@ from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
-def load_callbacks(cfg):
+def load_callbacks(cfg, save_path):
Mycallbacks = []
# Make output path
- output_path = cfg.log_path
+ output_path = save_path / 'checkpoints'
output_path.mkdir(exist_ok=True, parents=True)
early_stop_callback = EarlyStopping(
@@ -94,8 +95,9 @@ def load_callbacks(cfg):
Mycallbacks.append(progress_bar)
if cfg.General.server == 'train' :
+ # save_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.resume_version}' / last.ckpt
Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
- dirpath = str(cfg.log_path),
+ dirpath = str(output_path),
filename = '{epoch:02d}-{val_loss:.4f}',
verbose = True,
save_last = True,
@@ -103,7 +105,7 @@ def load_callbacks(cfg):
mode = 'min',
save_weights_only = True))
Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
- dirpath = str(cfg.log_path),
+ dirpath = str(output_path),
filename = '{epoch:02d}-{val_auc:.4f}',
verbose = True,
save_last = True,
@@ -136,87 +138,3 @@ def convert_labels_for_task(task, label):
return label_map[task][label]
-#-----> KFOLD LOOP
-
-class KFoldLoop(Loop):
- def __init__(self, num_folds: int, export_path: str) -> None:
- super().__init__()
- self.num_folds = num_folds
- self.current_fold: int = 0
- self.export_path = export_path
-
- @property
- def done(self) -> bool:
- return self.current_fold >= self.num_folds
-
- def connect(self, fit_loop: FitLoop) -> None:
- self.fit_loop = fit_loop
-
- def reset(self) -> None:
- """Nothing to reset in this loop."""
-
- def on_run_start(self, *args: Any, **kwargs: Any) -> None:
- """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
- model."""
- assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
- self.trainer.datamodule.setup_folds(self.num_folds)
- self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
-
- def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
- """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
- print(f"STARTING FOLD {self.current_fold}")
- assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
- self.trainer.datamodule.setup_fold_index(self.current_fold)
-
- def advance(self, *args: Any, **kwargs: Any) -> None:
- """Used to the run a fitting and testing on the current hold."""
- self._reset_fitting() # requires to reset the tracking stage.
- self.fit_loop.run()
-
- self._reset_testing() # requires to reset the tracking stage.
- self.trainer.test_loop.run()
- self.current_fold += 1 # increment fold tracking number.
-
- def on_advance_end(self) -> None:
- """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
- self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
- # restore the original weights + optimizers and schedulers.
- self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
- self.trainer.strategy.setup_optimizers(self.trainer)
- self.replace(fit_loop=FitLoop)
-
- def on_run_end(self) -> None:
- """Used to compute the performance of the ensemble model on the test set."""
- checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
- voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
- voting_model.trainer = self.trainer
- # This requires to connect the new model and move it the right device.
- self.trainer.strategy.connect(voting_model)
- self.trainer.strategy.model_to_device()
- self.trainer.test_loop.run()
-
- def on_save_checkpoint(self) -> Dict[str, int]:
- return {"current_fold": self.current_fold}
-
- def on_load_checkpoint(self, state_dict: Dict) -> None:
- self.current_fold = state_dict["current_fold"]
-
- def _reset_fitting(self) -> None:
- self.trainer.reset_train_dataloader()
- self.trainer.reset_val_dataloader()
- self.trainer.state.fn = TrainerFn.FITTING
- self.trainer.training = True
-
- def _reset_testing(self) -> None:
- self.trainer.reset_test_dataloader()
- self.trainer.state.fn = TrainerFn.TESTING
- self.trainer.testing = True
-
- def __getattr__(self, key) -> Any:
- # requires to be overridden as attributes of the wrapped loop are being accessed.
- if key not in self.__dict__:
- return getattr(self.fit_loop, key)
- return self.__dict__[key]
-
- def __setstate__(self, state: Dict[str, Any]) -> None:
- self.__dict__.update(state)
\ No newline at end of file
--
GitLab