From d7fe105eceef23555df224fd257187c17ba64430 Mon Sep 17 00:00:00 2001 From: Ycblue <yuchialan@gmail.com> Date: Thu, 23 Jun 2022 10:39:02 +0200 Subject: [PATCH] added jpg_dataloader, rm fixed_sized_bag --- .gitignore | 4 +- DeepGraft/AttMIL_resnet18_debug.yaml | 51 ++ DeepGraft/AttMIL_simple_no_other.yaml | 49 ++ DeepGraft/AttMIL_simple_no_viral.yaml | 48 ++ DeepGraft/AttMIL_simple_tcmr_viral.yaml | 49 ++ DeepGraft/TransMIL_efficientnet_no_other.yaml | 4 +- DeepGraft/TransMIL_efficientnet_no_viral.yaml | 6 +- .../TransMIL_efficientnet_tcmr_viral.yaml | 8 +- ...ebug.yaml => TransMIL_resnet18_debug.yaml} | 7 +- DeepGraft/TransMIL_resnet50_tcmr_viral.yaml | 7 +- __pycache__/train_loop.cpython-39.pyc | Bin 0 -> 9235 bytes .../custom_dataloader.cpython-39.pyc | Bin 15301 -> 15363 bytes .../custom_jpg_dataloader.cpython-39.pyc | Bin 0 -> 11037 bytes .../__pycache__/data_interface.cpython-39.pyc | Bin 7011 -> 10339 bytes datasets/custom_dataloader.py | 77 +-- datasets/custom_jpg_dataloader.py | 459 ++++++++++++++++++ datasets/data_interface.py | 107 +++- models/AttMIL.py | 79 +++ models/TransMIL.py | 40 +- models/__pycache__/AttMIL.cpython-39.pyc | Bin 0 -> 1551 bytes models/__pycache__/TransMIL.cpython-39.pyc | Bin 3328 -> 3233 bytes .../model_interface.cpython-39.pyc | Bin 11282 -> 14054 bytes models/model_interface.py | 249 ++++++++-- test_visualize.py | 148 ++++++ train.py | 58 ++- train_loop.py | 212 ++++++++ utils/__pycache__/utils.cpython-39.pyc | Bin 3450 -> 4005 bytes utils/utils.py | 126 +---- 28 files changed, 1547 insertions(+), 241 deletions(-) create mode 100644 DeepGraft/AttMIL_resnet18_debug.yaml create mode 100644 DeepGraft/AttMIL_simple_no_other.yaml create mode 100644 DeepGraft/AttMIL_simple_no_viral.yaml create mode 100644 DeepGraft/AttMIL_simple_tcmr_viral.yaml rename DeepGraft/{TransMIL_debug.yaml => TransMIL_resnet18_debug.yaml} (91%) create mode 100644 __pycache__/train_loop.cpython-39.pyc create mode 100644 datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc create mode 100644 datasets/custom_jpg_dataloader.py create mode 100644 models/AttMIL.py create mode 100644 models/__pycache__/AttMIL.cpython-39.pyc create mode 100644 test_visualize.py create mode 100644 train_loop.py diff --git a/.gitignore b/.gitignore index 4cf8dd1..c9e4fda 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -logs/* \ No newline at end of file +logs/* +lightning_logs/* +test/* \ No newline at end of file diff --git a/DeepGraft/AttMIL_resnet18_debug.yaml b/DeepGraft/AttMIL_resnet18_debug.yaml new file mode 100644 index 0000000..03ebd7e --- /dev/null +++ b/DeepGraft/AttMIL_resnet18_debug.yaml @@ -0,0 +1,51 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 1 + grad_acc: 2 + frozen_bn: False + patience: 2 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json' + fold: 0 + nfold: 2 + cross_val: False + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: AttMIL + n_classes: 2 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/AttMIL_simple_no_other.yaml b/DeepGraft/AttMIL_simple_no_other.yaml new file mode 100644 index 0000000..ae90a80 --- /dev/null +++ b/DeepGraft/AttMIL_simple_no_other.yaml @@ -0,0 +1,49 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 20 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 3 + cross_val: False + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: AttMIL + n_classes: 5 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/AttMIL_simple_no_viral.yaml b/DeepGraft/AttMIL_simple_no_viral.yaml new file mode 100644 index 0000000..37ee074 --- /dev/null +++ b/DeepGraft/AttMIL_simple_no_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 50 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: AttMIL + n_classes: 4 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/AttMIL_simple_tcmr_viral.yaml b/DeepGraft/AttMIL_simple_tcmr_viral.yaml new file mode 100644 index 0000000..c982d3a --- /dev/null +++ b/DeepGraft/AttMIL_simple_tcmr_viral.yaml @@ -0,0 +1,49 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 300 + grad_acc: 2 + frozen_bn: False + patience: 20 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + fold: 1 + nfold: 3 + cross_val: False + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: AttMIL + n_classes: 2 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml index 79d8ea8..7687a0c 100644 --- a/DeepGraft/TransMIL_efficientnet_no_other.yaml +++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml @@ -6,10 +6,10 @@ General: precision: 16 multi_gpu_mode: dp gpus: [0] - epochs: &epoch 1000 + epochs: &epoch 200 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 20 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml index 8780060..98fe377 100644 --- a/DeepGraft/TransMIL_efficientnet_no_viral.yaml +++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml @@ -5,11 +5,11 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [3] - epochs: &epoch 500 + gpus: [0, 2] + epochs: &epoch 200 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 20 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml index f69b5bf..5223032 100644 --- a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml +++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml @@ -5,7 +5,7 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [3] + gpus: [0] epochs: &epoch 500 grad_acc: 2 frozen_bn: False @@ -16,10 +16,11 @@ General: Data: dataset_name: custom data_shuffle: False - data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + data_dir: '/home/ylan/data/DeepGraft/256_256um/' label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' fold: 1 - nfold: 4 + nfold: 3 + cross_val: True train_dataloader: batch_size: 1 @@ -33,6 +34,7 @@ Model: name: TransMIL n_classes: 2 backbone: efficientnet + in_features: 512 Optimizer: diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_resnet18_debug.yaml similarity index 91% rename from DeepGraft/TransMIL_debug.yaml rename to DeepGraft/TransMIL_resnet18_debug.yaml index d83ce0d..29bfa1e 100644 --- a/DeepGraft/TransMIL_debug.yaml +++ b/DeepGraft/TransMIL_resnet18_debug.yaml @@ -6,10 +6,10 @@ General: precision: 16 multi_gpu_mode: dp gpus: [1] - epochs: &epoch 200 + epochs: &epoch 1 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 2 server: test #train #test log_path: logs/ @@ -19,7 +19,8 @@ Data: data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json' fold: 0 - nfold: 4 + nfold: 2 + cross_val: True train_dataloader: batch_size: 1 diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml index a756616..f6e4697 100644 --- a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml +++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml @@ -5,8 +5,8 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [3] - epochs: &epoch 500 + gpus: [0] + epochs: &epoch 200 grad_acc: 2 frozen_bn: False patience: 50 @@ -19,7 +19,8 @@ Data: data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json' fold: 1 - nfold: 4 + nfold: 3 + cross_val: True train_dataloader: batch_size: 1 diff --git a/__pycache__/train_loop.cpython-39.pyc b/__pycache__/train_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c5ca839312787cce6aed19f346ae40a7c04d83 GIT binary patch literal 9235 zcmYe~<>g{vU|=x&IVEYL1_Q%m5C<8vFfcGUFfcF_|6*WZNMT4}%wfo7jACR2v6*t1 zqL@+`QkZg>b6J=e7#VU|qu9VQ%sK4298nx#Hd78~E>{#cn9Y*Ilgk^$o68r)2j;Wp z@aGCd34qziCbQ)T<_bj#fyLQ#gmXorM8Is09MN2{C^0abGe<mEB1!_x=E{-Gm5P$e zm5!3mm5Gwcm5q|km5Y)C%Q5B1N6CX-sSu?Irj>G)qg0@Bs!?k03@O|xJS_|<JgJP$ z%u(v@3@N-Rd@T$qe5pds%uyOCOu-DA{4YVFqRDuR-zPIYqa-ggFWomkr8FniPm}Q$ zhhuVbX;ETwr6$uYW=ALITYQO0#U(|F$tAg|B^miCASEfOsRhaT1(lkNw^)4g^9wW? zZ?U^&mOz=DAw`Lqd8tKid76y3gq;$LQ@!2tb5dLqOA;a0XtLg7cFe21#p05gTyl%W zC$qTZ7KeX9NoIatV$Lm=kjjG8WRMGyF&mUq9KyiBkjfCnn8Fanl**jSBFT`-n#v~0 zkiwY4G>5sJA&oJGIfbQ#CyG6VH<crmGnF%iF@<doOA31mM+<8dR|;neR|`WFHzb^* zc-k3Q7@~NC88rEB3A^SMr{*T*q=w~}K*A(7$4`^_7H4jLN@`Aga!&Crf#i(T<m`g{ z%)FBLg2a-H;#+J*sU@XFdC9C$2ZAV&W4Rd^7??qc&4H1Dp@gBCVFBYp21bS&<{E~0 zrW%$ShInR>lUo={SZWxun3@?~7@8St7~)w=*s|Cca6mo6)WT821d>?@lH)AlYG$Zm zi020BNMXrlDLPfclf?_>v8J%~GSx7|^Oo=|;49(D;sT4ZgG5VsviQL~j$XzRo-6?{ zn-eM{n8FF>bM-Qn@MH<0@PrY(bhd?zwXAs|Ap25SQZQ^w0R=D{*bM#@mK3N>DJ&^S z_N1_+z-@uB(^)~L*05%YfTE)p<VVpGo-8r2i#Wi1@f3z&22EbSWJXZPLNOBq0|O|) zgFyi-#=yXk&QQY;%U8>o0?NybbC@PF6|w{~tYj(z>C$An#hRR7npbj*JvqOqC^flc zCBsUlTTFTew-_^v#2FYEz{D?8{fzwFRQ<}F#611-{G#mQg2d!h{g9%>ykcKZA6=K! z)B^XS#IzFq5>Nt<&&kg(&?~6C#buL|SzMBwpIcxj&A`C$8RW()11xDtFFrmqFS8^* zUe6{cKRGd{*iH|j6|4V*(f!xKuz;b4aUmnbe_&^8GTsu&$xlp)Pb<pLjYp(UO{OAA z1_p*(j0Gzhiv(ewECM;PNQQxd0fdWW85kH|f<p5aM|mo!Fi1_&WV*$jSDG82oRe5w zoLYQ~J+UmcC^0?t7He*1a#6k}W05k*Sk~M`5Ra)y1!Octs3f(xBwmy87F$74YFcK6 zro=7Q;?jcDqFWp=pWoulgB#3|lb;UGM7LOr67$kiZ?U8#mL%R{Ey*uR&bY;tmv@T` zQeJ=y2sUsgEG`1&sVEMxO^M0Lw|Ky+b5l!-GLwsMu{wqZ`8(ed@CEao^K)`ilR?GG zEiUK$jMO~u#DapvTkLLz!O8hWskhjIQj-&Na&B=36s0C-7H8(?-Qo@|NKMX6%S_HJ zsl3ISoRgoIdW#2a8O#cP=ls01QjlWb#FC=SidzE7FoF17u!ttxEtcZcoHR(F3xR@1 z8${@Uf{h(&&Pv8xobiyLh>tH)0Lg=6MvH-g0hFM@<s1Ve7ZV#J8>1Mb3?m047o!j( z4<plm7A77>4lvEe$nu|s35op9!7ISX1(xGtWCPPk<WDwEKE^5;d?^!>Us!IjCMA|6 zXWZfhQ}LN86}R|6kpoI+sYM`u5h!UUgKAVzatC2h-3!W-pa?2n!N9;!!x+Qd$xzEu z%UZ*-fMFp+ElUT(0>%`^g^VCxCqoI-0!Vqyl)~J?QNxhM0x9Sgu!7XovXwApvDL7F z3Y%W0TJ{q5EDku2ql9SzTMf%X##&I}#aY8C$&kXv${@**!eY+A%wWS%!;l3LWe17Y zurJ_R$WX(X%><QYtKq0&uVG2yn8Omxpvmb6_BKWi1o;l+&0-N)3Byy%Si{)C(2Sh_ zAej)B*>YDh-eS$nD<~}iizBj|6DW#7*{#YDTLgg9SaC^e0k*Q?C8!iz$#{!7B{R23 z6l4%6iD<IiV#~=-&nzjv#TFS~P?Vo^i#0MnBeCQbYffTPYK|uNE!MQ0{KS%5?8Swp zsi{?|Mf#v5$XZ;In4Dc?1Y)z~rY7dyVoNMa&rPhj#prg6y*NLuBsa05NDnNXS(2Ko z$yH<zveOwvxPS;(5aAC}z!+cT1LF9B2w$*BB*f$HAQ4cKzQqamn-!=y0|h1nGZ!NV zBOkK@BOfCl6BnZxqa0I}EZ%TQ29;8vbOyqF3=9k$3=9k)Cl$MaD_Bs<N?~eYr~y}` z3mI#fN*ER}r7)w2F_(Z+APb5ZO9?|3a}Bggt7WNSSin-lRKv1>bs@t-rdrk-h6QXj z%phJ3%R;7Fwi?zNwiI?7h8ng7><bxO7-F?zm}=Q;IZ8Nc*k?1OaMW<jW|+&|%*e=4 z!?A#KAp<z&vDL8KFcemmfGTbtP*or62Tr-X;FQZL$xsWbt|2@wNrqbP8m<)XY^I`n zDGaqdc|0Y2HQddNHC)mRwLB#pDLf1KYdC9oYPf5-7P5eBD%_C53-S}fd~j0&)G!bL z`6Y!fg};}Xk)cFzflv)wI^#mdTHYGo1;RB9St2O{(hLj1ZQx)AO+n0R7L*CV)hsC9 zz||}`6QnRmGC(RJO~xWn#DimfB`h0MfVmI?lnB5z4p<pD3z*@I>ePb#<c#>#ycBHt z0Ba5<<T!BB1XuFWpd<-qfeE;~z$!rQBFlVGK7e=>m-zu8dqL(Gd4gD8plkxI!Ag^> zm=$al5SbcW5>?5AD_@1`S}TRb<Ya~FS{(&w%SEBOR=>Iy!=a#p7aS0vf&m;$7>%JK zPy|98Dgkl^I1Io9JRovF;vk3KVgcn@O~zZGvLZdTimNEKxHP9kFS)o(QxsB)FefLL z6a|4afLi25CLmTIh_C^9j}=sK6yIVkD9X$$xy77ZP<o3suQa!yvPc)C23#M4N|+*0 z*xX`E%P-1JECID+(&N#q(_2ghDYrO5y$83V#N5<dY$f^e$;D+wpz2hU4_q|e;>#~B zDJU(8hd2{#Z+=qoEzaD;3UI+#e2c9lu_!&Y<Q8*cX>t*$6L^ayJGHX-7F$_jPHAfK zEyiR-9Nc0CwJg{mhCxCM9!|ITKp_(kPo|*KuNYKmGcbxUiZF>V@-YiA3b1gn3or^X z3Ni743OGg%Mm8oMMlnV)CIKcM#wsPksUM|gWoLl2cEP3HG*Hb7ZUurm6g5l>nQB=| zm_Q}i0+xjgpmM5&t%f0sy_pe4r!e+1*Rqy?OCv^x60R)n8rBq0$;i^nyntsRLoG-> zZwjj<Ll)lx{u<`l3@L1LnQB0y0t=ZzU6vH~Y^I`)5{?CeH4Iropi*uw6R6B6;aDJC z!v<o38|j>Wpq6%KUP@{OsHL4#nwwV~a*HFlAT_z9C@}{lnV6TCUvi5jIXAHYlF|kB z;q55}Nb{*Ej)8%pibFp+Hy#w!nj%GZpd<rMTj?NH28aN45Q|bkEO>&t#axh6a*Hi3 zGrhDZ^%iq+UhysVjMT)E+{A)gT<NJL@lYA|oc#3k)S}{BT&WcWsYRK&sd**0xDr!R zU<!DXON)w9^GYDGe2cv}u`D$$Gaa1bz>UJ<R84kp!Ycyjm|T!E;z7xt9h7)MDVa4T zEj~H-7E4-YdOTVRDoSGjbqE<4iUmMb8KVHB3L^)j2(ug$2cr;Ul?JxN0ct6Oj0U$m zVa-fNa0qJhfm+vjiMgrq@wd3*<8#41`S|!-Jn`{`rHMHZnIe9W8$nfQ5iiK0%*7=| zMWAF`1nQ^WV$RGfDT)FqDFG267vAE9HcyjNb8_O5J>dva0xmN^9so5*ia~WT2cw*d zm=K3X5l9piJs{t~F{l~=af$^PQTy+pPB2Qpoi&v$l|6-Z4r2;i3VRDnGh-A7xDU@U zhY8$wXNlrUVFb&AMmkucxKnvjc~Tftc;_&K>bVw{C|)EHffT_OmMA_X5up^}7M3V} zs2L(DqAe^@0#Ke<ig*i4lwhh*DoZMJ8j~ag*dB=#$rhF<;Z&g%#uTYJOexYSGA%4o zB2YcDDRM0=QKC?ue2PK~OO#kD4_K#Sic$+plsMQt9<X_gDav!0QdCk@TUeqbz+(@R z!3>&ew>ZILB%l!#@Q4Gb1D2MblTv((JGG)9zX&qg0BLlC8rYzq1YuCHf?^dEi^XQ( zgasdI0QIJsQa~hg4QOP5rG`0$Ns=Lp6*S6F!<@n_2^vda17}?pNE7%LE4Y_+i#aDh zU6Z{?6BN_nRMHPhG(50;0diT<1dt@SAP1MkERYNVQV5p1#hRU1lwMo}PA%}_6P)In zK+QlNkn<Tqy%bOnhljCB2p$6PezGPb#K(*vwNMOVgM17w3_*s}Fl2$6`;6JF!3--I z{WO`1CWBlCa>z<Xh@-&Pg9%Wwy~UD}pO@MNas<dA2Bs=Lgw-G^O{QBMX_+P94nD*I zV0BOjfY=}mb^xf;Qv(WO22d~-O=VzUSjkv4gMooTlL=xsh>cbPfC|UFywv29KDZkg zs{|2lNX`e#qEzxAZ-X${R#2>`Fa|SdGF8d><(Fh+=A|o?<SP`V7N?ddWacT9WMmdA zfQKjiAXbAa4G0Sn3ant|6Oo+2k8lE5PLuf-3n&h6v1ONoXh>LqTmcOV@BkMmIfB|i z#TMY?SHqCSn8j4W+{{!g<i*g$2;naP4RW!9Y6@_z#0VPVS_m2N;)MF7JhV791>_A- zpIjlKIJKm-ARd&`iW3w-LmLVu8L0{hScVf56f*OQOA_;vQxy{PQWT0y@{3Zzit~#y z(=+oDa}>(KLmR~k`Dq|=E(It6r*%C)O|Dy9nZ+<OZn1+$cT$UPae;cgkm9rm6wkM~ zk=#@SN?%26pll(K0~$C1jiN#e-&<V8C5a`e@hPC;O=)!5c!cCiW=I;D14^)SK?FE@ zzyvs@^5o~m7nSCLv=o(00Yw$aMh0dfMghhuF+_ABD?>>(pvVDXa0Efh2T+NS#W<T` zE)%HhP{LHhkj0$EQo`EIRICJzFy<_l1)%Dm5fowIA!Ft$4_wg`53cZujw8J)zTgnY zpb$?#cLg_p9~T8bP4=QBkl(@Cb{>c|A4Dtw5#VqaK!iKQvZAFRIdCw632;~m<mbgF zrj&uw0VKR;!^4XS)EKLhLJ2PvrMSZkl&EVMKm!(x;8JD*6L=JtVF5FQ&SJ@8g^95+ zWU-|%nlmsl#Ir#|%?cK3pu}5Lnx~Mckd|2j8eImbRZs&5lzH<(c{Wu6TD~e|<maU5 zfjbL+n(VjuK+Rfk^98ICT)Nz1E-KA~%Rr4O0)-p2yA8_ykb)3fSg}KWzYr8epfF}& z;$Y-r;$Y-KDTtt|P(lZk{Xx||xF7<x-a+N>0!Gka4rmAmoVr+NGo&!hWnKtwB!c9# z*cPzYFl2Fn(i51^xd7DR1$R7a7_zuacv6^qL7h<0a13iz94yp8EiX`(Ma#kPfC3c{ z;D7{|S74QBQyO5|%#vb-{DP9q+{~)fqGG*Ykg-I)U(9+1B}FqC7#RFCc_D$v53&(y z5cL*wesRGqmaP2DJaDZB3DQNNY|jrG3<IS!XihI$0uto_&6Sj-rdQsArdou(x7dqP z3vv>ZQ;R_1SA?xh=0i%usd*`@K+ywI%)q3;Bmm0u5?Ir4YF-LT9D#B?sGtDR#h}I< zcs7ACg{g(3grS2WixD&g(ac!OSi)4o+`#}!!_DA<tt_?<hAj3J=3bUsCXjdyQ#=cZ z&f-|WSp$wuD31%2`oU~&P#1!cAq5h1Zm^gG4TqPOK+55Q)FN;P4^%=!ieQCQ=uC?O zxC~Z+rhia6Rwzy_(L?bKsImnG0(dA+4K@@fSIY<*%;{jrVoG7CVN78ZXMm+T=(q%P z(K=8u0q%3J2i3^TMH@i^n?MAp^#U6qXH5f-Xd)s6RGxrGEVh7@f;z-i8Ylq|DwA=H zQD`z1!5c-OWCiY?fGZ17yZ9DMNo7GQIIR_dQV~LJ5!kL-pftn}4wU$k{P>jAvdrYv zqBS6AfD(fyw88)lDo62@fv2fJl}>67q;>~;0p$8yT+k{ZH7^C^-(pZpiGfK1G?2{1 z2g>Fm*gS+1?x36w3VU!)2Ne?Fyv>lw2%1>&1LuU5j9?0DS{BOBi$}{%dqC+KWF!M) zl?;mgXbLr%iWY<X42p7;fj^LD5C*#mGzticd2l-x(f|Y5P_!2m+0f<}*cfmn0Vcp< z19BNS^CBio_TzLJnnK(m1Zo9?$~I8999&m0K<b4O<}6U-harnKo(&R$@$8_+;z}mJ z5KRtnQn)1sF2CZzEsuCmvjb#nYSAqbh<I6I4w7gQCn%AyyLpCyDpBwdO-br4rnJ0U z9AE=7^U|?JNYP3L28Kf*e}Xa-10xqB7bqPdl_;q73~rw)frlWnpfj4_!3TychAftN zR;d5LeTt$3plD!+cu*AL5l{xl^ht<oFvurGM?g`26zn;WLR|h>&A`BL0^|=+5eM=I z4`UUkFQ6u1q#jV22u?kqITujVJ&U1?u}G{0)KF!rVQgj!2G#QnHH?xB!3^LAy(TlL zhYf1X#mC=bPfsmLEGa1h*TT>?{w?P0)JjMOL?klq_;{$&`1rGM*D`^_TojS^kyL3i zK|%m)qK7*hREpFvfC3*h4+xt1hdGm}=oBdI*h&jh5=&B{Zi5&Jb{3cbX8@k~_~O(O zNINwCGRSeDJjH<0q(oMtDOJ<~iknUl0qQsv^?+F2AOh4qyv0#alwXiqR8o11B`H5Y zr)WAz29(T-Km}S6s3is-Xee3+k^oH(6|Duawu6Y>AOciA73~ADKv@x-R*OIpT?7id zqVph$UJwEH1egGKEJ5BW0(GMrLB0S*DF>q%2NR<Jvj8(6vk<!wn-C`-2OpOhh|k8y z%Ety)t*LoSu%Hq=br=tCpBI4w0X(=FB@C9*%gfVCE6oGXASUJ%8Gu5U3rQdpB%lIT z3Z2!{1CQH+<_{A=(~f$H(6tocS<74O`NevmMFY23KqR;_y2YHBlw1Vr&lQ1&aBs1I zmQ_Hy6*3qO)B`o@iuIBbi&Kk0V}iF7a4CWJUO{nyO&PeLNG;L>r-5Qfi6RavctGhw z55$Pi%qvMPN&_`lKxO|ewvx&MP>uwRG!%hj3p9{a1S*3egG<Js&|oV_%u7iuE&>hI z-4e`AEGfvzFUiSF(krM0v2L+J=%Ol+V)o+H#H9S9yrLRV%z_d&IC7A35;&qkaeRxz l2GSF=1Jz2!pgCvIBq(S?l!sA(kp~Harb<B#Sq>I~8~|;~>=yt4 literal 0 HcmV?d00001 diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc index 4a200bbb7d34328019d9bb9e604b0524127ae863..99e793aa4b84b2a5e9ee1882437346e6e4a33002 100644 GIT binary patch delta 4619 zcmX?F-dw?#$ji&cz`(%Zt-m43A$}sC3}ezpZ576P?G%wYj47fiVlAvuI_?Z9;weHc z3@JjXs?E$%TB*7z64^`>n2O?37*d#Wm~&a8SQr^n7*ZH>^rG}qSr+IoWXLi|k(6YJ zGE9+5HJr_mB0ZNm$_UDqnadny3}wsCWsWj|vgPJ7N13MTWtpX#)~Cp)DD*NhGNdR* znWt)`8mF41D5WT;ny09wsP?i&S)^*GTBcewGp4ACFf=npS;6?JRuUk-bt-F?O&0qC z+f<8Gt5oZBriDx~Oi^}G_Nknyb}5XhDyeqOj46sK>b)#c4iK>v#T1QR<|v00hA77r zdxmtzRC5uAD5q2h7lv4aDCbni6wPeb2~0%}j0_W)3gc3pW0<2{qFkfgQr%PCni&}x zQXNwy%^9K`ComQ2Y@Wqr#uP5z&XC5K!kWU?!WqSq!k)s>!Vtxp!kNO=!Vtxl!kxm? z!Vtxt!kfa^!Vtxg!k;40!VtxoA_xlVD6SOY6p<E&C~j~hFhudRGq5m3@dh(!N^DkV z6=G6zdkKoTmkcZn49OgzD1>4W1_lN}1_lOZ1_p*=tH}xMYK&}?_bZA`p3BZ&&tAh2 z&z-_p%TdBp!;r<9!qm*vpH<6Q!n=U4hI1if4YMReEmsYf3q!113{x$4El&x54R<qR zmOw3c34aYk7DqE9NF;?>oS~UfnxU4rhIfJBLWZKO8jyJ`HQY7~H4Is-3xsNT7cw$3 z6lR4pq%g2BurM?;Gcx1}OpaC(tQQ8UXl8U_h?S`2t6^9mQX*Qzw?J$mLo;Iy>p~_* zhQeDV;tM27#Iqz*SW;Mf8PgdTGS)D}OQo=-u(vSO@~3csX@L|@FfEwE1*U~kxWTk= ziFB4siFB513J)l3dzoqkN~E*oK;fGr*vk~t%vd6wC7&V$3gBL*66q|36rmK+6p>!0 z8vYXLEX5SD6vh<s6p3DDuu7#G!4l~#<rK*jsTAoH8IWj=P>FPwN{VcXT#6Jpp|UhH z)_|;4ot&wxnXXo%zCfZxqnV*ZeSzjehIA&d^R>XD>RH+;Vkw%vj3w$>IxyA(-4rd5 z(-$(PGlSLYfz|4#@PMIqFEffs1}O}|44OKV_i|3-4q#wl_za5qDxJv}C50y&2(eGL zS8<rUgUf<3aPmj4yvZN9tR~OoHk=&K?LJwMM};wDvKh~Hrk5g<`FY<kHckG+Ys;i* zI60iJdGiy#^^8oKB9r$CNKU>ZV9n;mz`#)CJy}+;j<Ia=96_ncI|SpIs-&RyNN}+; zFfcfSQklr)AR(>#q#7m{hFGIo_8RsDEDISHu$Hi8u`gt-Wv^jdz)=HXF)ZMOvRD^z zfmn<Tg*th9V4WIRbYfGVR>Hl2r-pSQV=Ze9YZh+__X55ehAjU26s8oW6y|i6g^bzE zMQ@O$QkX$<Fqsm88WvD~EfD~z0kgOl2-dKsuq<S%Wi8=eAXLM=kg=AzggZ+(OC(FQ zhPj3zODu)8m#LPyhP8&dh9!+Tm_d`x?-pxGQDSCZku(DX!%Jo}1_lOAmLf(FhY1vN z<{)A^NR;swcV203d~!}=adB#~ruZ$6^wg60l*E$6qNd5GgtbB<85kIDv6g1$l~~;3 zO)O2%P0cHb&&*9sPc3Q%nNk3jFUc=T&bY-|T#}faeT%g?Be5X$7E5quRccWdSiYbr zGq2<pdqHYZZfQyCWH*sU{Z3HyXo1oPlK>+NBNr1FGY6vpBNrnVBOjE-!^Xv8!pOnM z#mK`r`HzSXcLW0i!%C)G%sHuflY>Q77}rg%5S_}nW-_0cbYdRJHOU|%5k!F8rO8&r z2V$p#h%N>OhA58m;>`HG#N5=PH6V8}XQosX@q%Q*1gLB(Dr8__*a&hPs0?EOxrd31 ziG{IBadM-K_~b`g;`Ja24p=cZfsuisgkb?=2SW<uLZ%6f{R*{A=?t~Z=?t|jHB1Yb z7BaXn#M;%e)-ctuN-}^VyOl|jp_Z+NZ2@x$OBN`tu$8baV6S0Y$XLr(!@Ph4#A0Nq zVOhWdN=b|ig$gyyMFk~n3pgQ>U!Nya!j;0@!O+aa$WX$zfV+kviw9I3rLZn!sb#KV z&f-mBOJM?8Q^K}@FP$NUeGwx#<z2%f#{rWo;p$)jr#LQ<Dlm&5oW6LPnPQk~nQPf= z7#0ZBFn825EM%(XC=mqLAfOt9qeQ5NAxjukZZ<PEGuCp}aDezVoHZQc43Z2f+%^p1 z3@J?TvR25DVFF{ZnE?YM7)CNMGL(qaFl31~GlG(HHp>LYqMj0Qh)5@62jc>X4hB$} z4H8)(sn)@m!XpV5XlG1g1Qnz$U<+M9#WX{_a0*i}gC_svyW&kQRf4X$1tpaV&iQ#I ziJ5t+MOF$$pmYz;7)79LpvjCR0Lm_!%n(Ts8!R<BUc#Ai>*U1}hTJznNkxr;fuTx& z@<mB?uJqKr)QW;4yP|EA1tgQzwt}QVxeT1TiWETH%^+e6hyW$Al}tsuC(o4BWqdOE zl&s9;8XeKezhyb=HKlH`WSACI-ePgf%t^h)lAT&vTr?M?36%Yc=7CrXKna#BttdY? zKCd*lpt1;*!iyGyq?UlBSc?+#Qu4t0t!NQQbTLR&up~b|EwdsuCB8VbDm6YSG5r>A zd}2yUd`W(Ma$<5u>MfSM)X7;=!i*ax*GVZda!p<!C8rw2Rg#&L3eG6Sw^%^zA^}kL z;R0oy_@u;=<cy;2lg~*>xgQ4E1j;wJxZ-p26H~xN$1YIzF#=_6R#4VvVdP;FVB}%s zV&-DxVFB^D7=@TDz!@7PF2iWFSzcO^Q5@`}4WK;E49U2lhC<QS$p<xsSx+-CFg%-_ zFH^y|XYvafMaFNFxn!qlfpT%tEl@$mRJ0r9lszB<lzk!D8sxH~BMb}-e<z=k4Q1Ro zSwT*O@%Ls^xpqbeaBOV@sRx?^imjE*Mc~3C2_$+I<Pc~H0y6d%Q(8e0l6tVSW-%}@ z{GTkMV9dxo*;7GVWgST8CJ+HCM~XmZ6oH&xbO$7M4@AtGJVQZNbqa_HcAgjm0|UtN zVo>vlgOP=sFP1}<gPDVqLud0X1qWtE&&kp%?Mz;blM}g3C*M?2X7YwGja8MId?3tT zRb@us$s1MUjX||ukqgLhR}kR=B0NEa4~Xys5rH5g2t<T}h_K1dYEdrPAZ8AT$OWYv z_M+5+oW$f*NGf9miGw{<1Y#9~RB(gK{IbmA%>2BfW{_~{<cDe!qU9iF35X~I5lSGU zVlt<?88<hG*#aV3C;O{QMSx1=qEwJstSPBUrRld=GV*g%i+VsRdO-xZmhS_x`a#45 z5HS%%Oac*;K?FD$K{=}k9F_?T3=DFc52{NsCQgH-3UCqvMX)A!5h${Xnn2oSgNQT` zF$YA<1rdm>HXp<V*KS~qtdPXG3M2~329S)s799Ve_)eW1qL~D)n~K(h)NBOHfD%E) z<eQpJZtx;NleGw3{er8iS0JU})CNlSSgNYb+;ngbgcK&*!G?n3pnCI8EfYq@9h2W_ z`z!ARnF97dDAN_~gLA>bv7do~p?`9Mjuqp!$+LALoWV&RTtI;mToK4&n#@IDQLq3w zb%RYi05a$hhyZiJrh)R&tjT7&w)ICr630NqIS_FiM4*KsIMJU0Nt^`{V2^^bb`dyl zfnvSrJV@dKD4N(IAq=jkFM>obfe3I|ZwIk1g9xxgu7FrvVCR7HR@3A^y3@I@fdsFE zh#TNg0@*!h@;1FBNP!J1ABt{+Wk9)b*<@?|<=po{${&D;hhP~{+E_iA)xa$283O}@ zn<j6OFh~-VCyPWutj8b%l)j2Uab5(9UT|9hJ&%J7DtZD^40bS>0H;#0qn|P`FbGYa zYaq?|bn;FEC&tdne+;6H!0j5Aq8N~_SP&5hBEmrh5H#_jOHCFu6K6cQd7+^pBOfT? z-eLn4|M|H^;gcU28Oec_gQ|^N?4@~`Y57IDMIdu;F(;;^6va(8Fg9e|J2}bNg7L)U z0CSbe`;FDPKY^_L0wyMZG`5fhwa@l5GB6agFfcH1FmbULfka+1FflM_GEVk3nPFQr z1Ed;MooccdfvYDDkRw2iJa7$ki#a*5<Q8jHYEgdiE#}<JydqEn&}6;EoRpY8S;^E= z`V=VhgY4h|RTj(=j2tX%j9UNMn2Sz<LWdh<V-dt*lSK@jA$jZ@$ky+m01$o+(hF*e zF)*<)f?&~)$<}6{lz)Ph`~ne7j0_BrssJ1j;5ZOsU|{$(Ilx?1iItIo!B3L~Y6VC= zH~>JdDPjZb1g8y<GhR+!Z(hS_K3T^?iz6>TKEEU*wP<pJg?BwDxfFp48@D1**{jJ~ zlmRjbT)2QsxJ(ci)C^9aH4_A0+G{f1VlPh2EyzhNf;7L&^NX@mi;6WFiv&ToN`UeN zV-Y(e1B0eV5jeAfiz#qKvlk?mWaMNfL5h+EleH{ElfcFA8&C}`3@(-R^78c3O7oIS zGV}8ibBdOOT(ANh)^4|$lZ(r4F$Z}16tRJX|A2(q3KH{D5{qv!6{Hk_>VPP|#GIs3 zP*Xn<WJGaM$>fKY6&j$#3r+>Mm^1TAia@ak&Ly{)@{5aJf=mM!Y9A(dTDh=+;<_k$ z@?|Rz0UnU2Kt(B}<XkdY&_sN)k2ODI#^e}lBhF%>dImiKX#xJp)2$Wx*;yF5z)b@- H5X}kzkaKZ= delta 4613 zcmZp!I9kq^$ji&cz`($;es*&bOUy(*8ODN*+A56otXbNrIw>M^m{LSj#9CORbln+J z#8ZS?7*d2%RhyZkv{Lm_B(j+%FcsybFr+Z$Fz2#Fu`n{EFr+Z%=tt?NvMexI$dF~2 zA}PrbWt1Y7YBZZ6MS3oClrfYoGnYBa1j?43%N%72Wy{TFjxtNtug@}1HA|6CQRro2 zWJpnrvPjiPHAyv1QA$xxwMbD(QSD`mvP{)UwMw;YW=v5NVQ6NIvWD?fttCKwn^e{; z+bs44cBz)B)~Pn>ObeM}n4;{X98x(`?Nb<2RZ{Jn8B^3#G<sR093f&U>M5GN%u$Xh zz6?=LDGce1sTLv(QO>E3E)211Q7)-2G0ah}QEpN0sm`fP3p`TY7BaFjNHU~Inlmsn z*f30BDl|y-Om%K%WMoM7Otk`80Ag?MVKQS1muP25V@zR7VQ=A#VoBjh;cQ`uVol*n z;cj7wVoTvk;ca1vVo%{q;csDx;z$ul5o}?I;!F`r5pH3K;z|)o5p7|J;!Y7u5pQ9L z;%R4KVTj@lX3&(}EXgXwr0(bTQiOql;UxnL14A+=C`O@Jl!1Xkkb!}LnSp_!IAU@j zyBZ_=<O+7#dX5@~c<vh38ishD6vkT465bkyEY1|BW~TnETCNhl1^hK!3mI#eB^heD zYq(t)V%=hxYI$mTO9X0oni;bMYk5ipY8bLOni)YNDa_&w&5Y6vwR|;v3xpOj6cyBf z%wwtHv0<oT$YNa}T*J4Jk&&UWAe<qEfrWvEp}v`!ks(jOkfHdD0Rtl#Mlvuml!$<o zH#530#7fli*Dx#)EfK5XUm(7ap_#FUZ6OmQL*auGi3O4+5?N9yEGewLjOmOE8EY8g zrNJS~P%Dtak;2}>P%D_i38sZoxWKe<3OAS*DUr#NEs@EpmjeZP3STc%jbMpPmV637 zD9C%6ni)%EvJ^nUo+8rAR3ek5m?D%S3S!p?l*nW$rHG|4riiCV^fH50D%S{=$YiOc zNTx`oNT<kvL~DdgWU^FKWK-l)q`*m-rJ1n?WUX3?Vv1sl(j2A~<rJ04Je;!i>LnTr zBug}#8A>!3Xf0$&X99ag8!W1krIR9-qS?z>qLHNwV=d52(E@p6A!9l-Sgk%-tw9P8 z7;5)2qnKn^!;r<3q65ySAtgo&j8k;Mx=d0Sf*CaRCQsm;#uzwRoomiyBW|n7f4NL2 z&*gTXT*<A%7&3VtcNSylWKN#zOhpotAMw0lY@U3c*H-QpOG;*5ew9dCVsdIyetxz- zh^t$cS(0RErfE8vkFS}7BPlUCJ1IYJb3flkMz<mk1_p*({Aoq`x$y->sU<~;nR%%x zw*=Et6H7{qQsYxAN{SMbOY(~}MT__t7#NDuK|}_KNCOd>AR>Emgn%8JCy41ad5%CG zW6fk~L8-~cf{AP#EIf==Qj-n2q$Y0=6y#!KU|?_trFEIfrv<g@t7@2B7-GF@*=yJr zuq<R)z*@qV#lDcSmK~IKQ<$=uiXuSizJ@)8IfbQ{sg}KlZ2?yeRK)^rD2sIg4~WIc zP^gm^K+wc32ovkUMlvkm1!)Iaip9vf622PNX2ulOY^I`DCHxCO*-8+Unp4;pvedHX zah32d5UOEa$XLr-!oNVchIt`lEprKfmPnRpmRJpQ4MUcA3P&$fEprWP4RZ}k8gnp% zCa2#mR!~R><Q2t%f}g3NM3d<jTV`%zdTQ}4zM|BU(xSZhlGME7{G#F_Mo`HCDj<ra z85kHe8E<jtmFC7L=Oh*vrxt5U+~P=2Es0M_EJ-YCo~$XX6%+w7lC?B5uf(D#8x$u! zAOYUQ()8TaypnjZEk$i0i9!%j0g`3TNzJ>(T9RLsoN<e_C^0W3uP6&7&RSfOn4CR% zxp1R(A1K-kK;;&b03!<{7ZVpV2crNZ7b71N7l?%MKr)OxOgu~?j9e^Sj3SIxvXc$D zL?++XG~y0tU|?9uRJ4?VfnoAG5f$F8ATyajW-?Yuf>nJMDP>$gIa5?RHXr1;WDt=E zB0z4{WJ5}Wy$lQtQ5@yPnelmvxv53#LC$5)OsOd11t|v;ppvzyh=GA&2gn|9P09#z zD;MMBhoVOHpn{kKR`?1rGBA`dEMV+lNMT&a)W58jDV?E~xrS*0(?SLphFFtYmKvrS z7D<K_rWB@DCP{``)*99Y%q1*Ykf>){zz#_SHOvb*KrBXvLa7qA1)R_X!39nV+|5ie zOts9lY&8rEc<O7II~W!+)v}lHg6kgU6qXi_622OSEPhZ4+sxR^Sj$lZ$_o5795w9X zpsI}3h9R6Gg$Yzd!%OHLC?zyFeKs@JFk}g3vrJ$t;wurZVaO6}X6$6_U|b;5!2l}g zK_UxSMLQT%*d{CL2=nv1fC_Mic>WY7Q1YC-O}uGxu#N-EN~WUilMhH3a-Ro<sTu<V zLzVtyLoVUT{F1VaJ0|N&CaY}+$$*k5I0lLoKwMC?7HtEuK+&+0sc7%yZIZg&w?HCb zLkysXh)mw6%M~hiizUOfpz;=rTV_t`Etc%m%HpC$Ahn>BT(lU(S_X<;E^sN8SDIT; zSp-TBMJqs}f+hL!X_*zNDe=XbRjKhwiRs{^JGo0rjdAJZjZ%tUoQPC&i#@|MJ}omR zHHs6Ia*`5Dk~40xlw{_l772h1V~j7_38FbbYT{w3X(bb+*gHDeP+CN3KPWLNff6k% zA0rDR52FAh4<i>74<i?tCBta2xm;S2Q5@_<P*PvX3~>NB4{x9Rkz1Jc3<Cqhoyo^# zDj4@oc92zMd^0&ocA5z&4Hw-3WpbvXy&!Mz0}-1+1SlYj!0GuI0|UcnkPtK^3uMJ6 zua{$+>>;Nw2387Eb&CsBDC8t2rREfW+}tbI&d3dpC{RWyT0dD{K~(xWC@(@Y6v&=i zOlbv0P<aPgakirQ3=9lkC+8>_GqLzhUZ9|@vJqtK77zi-Q$-*Hia;JIx(5<_03zm3 zey<?QIDIm&qBLg`hdKu{2Rn!CW<x~>W^Q{%28JTAn{_51)DfQCuad>+FnOJ-`D8;? zB}PXut6Nox(P{EJ)ksF?$sB6&#u6YiLG?wE3y9?oB0NBZH;C{B5dk0~5JZH4h|tNi z)S~KhK+Ieakq3%R_M+5+oW$f*NUC84iGzJm3}TglRB(ffzOu~X%>2Bf7LafmNSG@m zKfu$+6I3qWVhc(IC9on;d0A8nQc?~gKnbC!62z(k5!E1Ka)r9Gb~A{@4I)}W1gOB( z6oe!Mv^unR@<DYm?tYM}2_V9A@<(+Exv3!LG!QWzL?9AC0s{ks$Yu!*DaOQ^khB6$ zSfIqC$qkORBCy>HK-$2S^g<8|Ek`W@iGdsi&bOMZ5SOh5iGp$hBqMJCCst6rM^9d@ znFKDsi#CDOYzE7K5<}i(1Fa@Ec$U{>EdmwjMd0H54M-`z;yW`p9Z`H2?F1VNikiaB z0@@~wjJqbg>i8?~2AKl({~i$Q0Ehr{K`FiHAOizK>*Vb^R*X9)f6|F?2B&jyfdns$ zG?|OQqF@1V5(jHP1TyFdhyZiJrh#(Tgvq_Sw)MwB5+^_exbQm(Vxfg0IPsqaNt^=_ zV2^?_d{GjJn+hT>fQXBrXkv$iFgWfnfkZEZ2yj?~ikzaWATHP;*FY>Tuya6ptYUJo z-gNF8Ai<j;;ubiRKz2`>%&VWoeH$bQDm9Akf@MItan9sP`pdZ=f@B_nh{s?VP>x<W zImW=O{sjXAgPSIAkuXRS)HNvr^+bxEg1DgcRRoIjB2e^#>reE&4l=0d8AvhM!C(TM zO2LkP&cMLH!N|Z+ECgn-aPzry2!o3PDGn|U)}rSOlLZYO8M`J28%7&}n*=OH(ICTO zKtwEv2m=*P&@_oIH965toN@c+ABKjEe4rZQ78|I+lAl`?Hrd?RNDizVRB7B|FU`wL z%P-0;0-1Y@IWZ-rD0XtYu_5FB$vcfL823$HV6HM*%tW303&_fEU}Ca|iG@-Z0|UcW zkOx>m9$?~PF9M0YWME=o&}6#BoRpY8d9KL}`<LK^J{J<X;2Mep<dTBaqTE}o#TkhO zskfMu6H9KfR;3o@7vEye&CCNgeKc8%xEL82CKs7nN}pz6V7LIXiU(AoFiS9UFtagg z{byq;Iz3sy(1{ylaS_CQld}z-Avy2|$f-X;;UIhuq!-kPVPIlo1i_+Tle^76DgOp3 z`2!+YK;<qthQI_khCmVXbn*goRV8*t1_nP(7N`{<_26&-xuysdhef>Llml|coyq^r zYZxsj*H~zAl_cjD#g}CkCFV?CXW?zZ4YC7N7`YXJifm2RB18`X)J`h`cN0J@%Jf+? zLExo*5hw{x*0Ge<5oH9mYKu5Qlvfcr+kp!#aNM#NB$j04WF|ohm1UDFEJKsP<?}mG z^G_IDZ0qIa>7|wCC6{F8=OyM8tpfRCHK_Pzbi2izTwHdGIl$AW2-Ni``UetbD@e>s zNi4p_RFF~xswATL5_6JDL5<@?kP*d2rIYooDhxo$7##7p*h0Wf@mtK9c_l@l7z5{< zTTJ=IMXy1Yfs46Mlh0eZuz|CG)MQC(4_Qzgg9}%1S$c~tu_!$^vEmkMK~ZL2$>b7i oWyZA0Gpvm`Gll9I^aP|ucqU)AR^(-4VdP@uVCG_!V&q^10C3oF{{R30 diff --git a/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20aefffd1a14afe3c499f646e9eb674810896281 GIT binary patch literal 11037 zcmYe~<>g{vU|{H}-;l)a#=!6x#6iX!3=9ko3=9m#W=sqWDGVu$ISf${nlXwgg&~D0 zhdGxeiiHs(#u~+%!jQt8!<Ne)#SRu@$>GT5jN$~dS#!8@xuUqhY{ne!T%IT%D4REm z7i>0PE<Y0kBSWr0lps__C`u@WA%!hRI9DV}1S%#PB??w879|d*C88w2v}BZ23PTEe zj&!a}lnhucM~-Z+T$CJ`&6y*gs}Q9KW;5o<M=7N+q;TaZ=c+`hFfzC^q;RM3v@oRb zq^dSEN2$3pr0}NjwJ@acrLt$KH#0|RxHF{irwFt#qzHgjXr`#AYNZIxVN4ND5ouwK z(spM^5k;0y)kzV{W}3iMWP}k)x>35REDQ7&GGyteh)Xg=8Kg+08q8)$k(|pMWe8<U z&1H@<g0iLOGDjIh*)nsPqfAnDvrJP>Qe;!)dYKp*Qskq|QZ-VIQjJp-QWR6oQj}7Z zd)cDQQ#n)3Qy5dkQ&f7HqbyP_Q>{`hni-oJ85tllsVb@F&5S9kDQdkeQ5Go-QPwF8 z>5QpnA`DSBD5@=77-IFJY*TGxn4|2X?4uk~ZBkhmIHuY!WMpNKWJnP=XJBTqVMw(G zv7)ReFc$twwM?}E*_6T<%%G|95|pC+G#PKP1SFPZfasLOl0;3$TU;)QC5b-yi7BZ? zAkmV-lw3{5TdbbBiRq~z>5%-8)V$*SqA(*(##<ter6u`psfi_}MX9b8B}IwJCHWw2 zT#2RWxv6<2sYS(_jJG5_bCVKt67!N%Q$U6nr<MdK<`(3n7A1omfsC18Nn3=0fgzP4 ziZO*DiYbLLg{g%hiaCWjm_d`}mQYY)UV5rueo<~>PG(hNNoIatGDsiH9A*Xv1`yvF zltn=1)G*dE)i9+nN;9M|Nid`^Nizg9XfpeygEc(JcX$qFFg&nd$)L%6i={X<C+!wn zG1!H-m@{+Ji+C9r7{J6YPyLMi+*JL_oWwl+^8BLg;)2BFRQ-^m#Jpl(Paj>E)YJm^ zqQtZkeNbQ)r<N4!CzlqN<mbj`6{N?5#B#s^t5;BYiz7ZhIWZ?EK3)doP8N_G7+Dyr zq%nP`2UF>%$pX@rmzbLxAAgH0K0Y@;r8Eb`=82ClEKSUT$P@{|+zWOzgb)OYaWgP5 zNQ0~cg%1Z~5r~fw2Vkdw6eXd=K@@X}csoNHV+v~uTMK6tOA31mM+-v~YYJxyR|`WF z8zi$vu|qOx6bB@0MscPHrU<n#L~(&LB|{W<ifD>h3qur7I|B<t6mKwtro=4)uK;&n zPaja)fP_!tFGdE2pwxn*)Z)~<l46C#JcY!hVsP?O@NjW6RR9HLacYU4f?Ix(LUCqZ zdQPf>hp9qxeqKppW?pKMq5_wKf&zr_%qu7@Q7A|(O3W>`0t<tr(lT>W;|mf?GOQFp z89<>VU!f!;RUs`uCndEAW>98cI#?8Hoq|SwUaCTVkwR{1PDy3~$WD-BKt>g7>cNf9 z%}+_SQcy1`O3cht2Ps!CNi8l>hs%Tf2u=)E3Z7{SAw{LBItn1MVui$<9JqZ7nYp>C zDVd2SsX3JjnRzAo3PFy(dJ3V%sR|_-nZ*j3X>bEU{sB22=FH4ug|y6)Vu-HX)ZF}{ zN<9Vtl8n@%^2}n8WvPi}P*W016w(rNic=L5^HLze1u_hTVFoIcXXfO9-KUU~Se2Pm ziSQ`MzWC(C<c!q#;>@a4D+RyO+@#bZh5R%~EP(??p(G=*L?JmbPa!E)0jxMhAu|u` zO)yVEBPBI0u{5W|)>u=I3+yMim!JypB?AisLo%qc1LX}67G+>y0HtG44lgcYVqhp? zXl7WzxR8O7A)TRyA)YA{L^6X(7O)6w4O<OEJX;NO4MRM84Py;MJVy;v4MRL<4NDC} zJXZ~S4MRLPSS3#iV=YGsZw*5hX9`m@Q~#V=&Jw-_{56~l8Ecp&8EUy|xLg=wD`J>x zxode!1ZudO8M6dyxl06U7_vB;85amGWN=}K4XfpaiPZ2)Go&y}GeFczGt~0c@GTHt z$WSz)M5Klfq`sN4=vEC+4MP@74Yv(L4MP^|0?``2g^Y|0g%fHRvUm`@!h&#y6b2Rs z7KUbKMut2ALx$pi1`Lc~7|FoMP$CA>-OT915Gz&7U&F9KyhNgge}UvehGxbZ)`d)r z3=<d&g-WCrNS8=u$)vEPu=X;hGcE*$hinR43VRDftw0I~m=;Xo1k*w(Twq!_g&RzZ zl*nbtm&j!)r0{@pcrR0pV2NCoA}E)q2=+2HGnU9@DWwR3a(ORPiCmU)icpGZibyY0 zjX;T9mP(2kxZ0EGWd^HMtr04b%Th~`Op!{FPLTnL)(Dr#WvPQI&=jc@`4oj-mS)Bp zkgXaiiYbaIN^_V}lv7kvq`;LUYcpdybBa31{B*Vy4G>+!5U*LHwLrQ=yP2UxYk|%} zhIA&d4|Tz!T3LE2Vkw%vj3ruG`Y_f4gA^^0?-nwqGlSI{g4G(O@PMIqFEffs#wnnh zL&q=V7GqHqo0E^fvv=@I21W)3O~zXsxrr6vT#(ENN~%x{VzV+ZFa(3jQvn80ip^q7 zVXR?DVM<}{1=nYOMLY})47b>W67y2>a}8HA-eLh2<13kNF=r+w-r@|-&rL1K%uOv` z$#{!7u@tE`Dgxz&TO2kyiMdHBiFOeT3=E$^_Eu@5m#2F1@tJv<CGqikHaYppi8;k~ zdN5riJ2!pSOb|F}4|B3nku0cgkp~e<AOc=pf*n|-!oa`~2@+RjU|=BGy60!lfWRBL zb;d+k7e%&pMe0bAVM2se(bTcZlnAS0U{+}|Wq$tu|Nl!+K;L5Zt@H>E(`35ET9%ko znpz~sz`&r%lx_?LukF**z~GbpEq32Z_r%=XM3B;S5Cg0flKMD&E8TK33w(mW&bS5T zhPu22C18-wp!^a`15M^zjM=w1vr|(Gz)2l!fF@J=4!PFXfjj!_)2l_+xj(F$VqXN) z7Lb{moT|xui?uj8F(>sFOLAFa4pi666b1%{mnT33sF8My)3?$!C$YFBGg*`67HeWo zK}O;&*5b_c+{7XakjJ6^0{bTIoqakaU3{?DWV*#zaf`81lPU8uC{T34{(t!g#0E8G zK%Q|-OUukl)nqL)2gNR1QGQ8cN$M@u;*8Y9B82H6fBNU8`lsDu^{sR)P1j_+#aMum zU_q4+I6H&ts1k++j5Q1k8G{*CGWuyU-D1isxW$@SRFqf=(wqosdV!(|krlc?o&jYA zP3Bwt8KtT5#Tl7tCGk0#xtS%m_=^)uPzA~obC3lzO>VIkmlmWJff|C4JbjB3Tn}gF zrRCq^0yDw&>Mg$FoXnI|pTwlp9GA@Gl3T3c>i8BvLJ2q$Z?S;dE4Mg7vfu{MEv_7p z9tfK=FFrXZvA8(3_!dV}Vmd?zPik&KNo73P{Nh_I`30$Yw^*`@^Yd=8fZ8!dpw{Ir zmg17s+*|CKC8<RznMJqQGK=FuG)qoqamg*V#Dap<yp$qP*?5aTJ~1T)+#revd7}6h zb8$)0E%u_+;{2Sl)LV>>x7eNYa|`l|Q*ZGH=O*Ulq!yR>CRPMzR;At&EJ_76J>pA1 zP4Kk*qTJ#l1yCGw1*aC4rskDoCg$7{bxcVK$@eHtb<9f%26a`yeTG}S?x0NIoS#=x zln8PXC<Eja-C_gLr75>q!AbrWD<nnmgHtl7`2<Nzw?v_d72Ng*Sqw3h4Js`FZv7!S z>lPQ16G83nTO6>S)Ga<xBqOA6i9qTCkNl#{Do~3y2jnDhzb7R>H!QWNBr`b?BF5$e zF8gosW#*<MTP6WDp)?iY;V@9~aEmXaxC}|9rf3l#C_#b?XmGI!sz!=HrP?h{P-KGR zr6?Sv22}ps;w(x{E-fm~1Vv2REzy$H+yY2Kh=-(_l+5IkB72Z>&=5rwTXKFzeo@IS z{)+hId{CdS1k&n`;smE2aFXCGhE;L5xZq49B(^aU+XRVix{~o0XFMc<#>Yc48>mh! z3SwYj&;eDsb)fo?k&B6q5edS0Of1ZNjC_ndj9kn@j8cpoj66(SjC_oGtbB|T%q)yt z|2ddB7@?4fg^`C*fRTrhhf#=;g;9u6gi+=%6ARNHmS0RfOkX$@S(sTE+5T{_voLZo zaxwAz;a~&lVEb1kiL)|B>9vB|xuD_-+-p6=z`#($uz;b4aUo+2b1hRXb1h3cLoI6! zLl)x#riBc(Y&8s7%qfi7OhrjGOfC$uMz!oU><d^HGAv*%VasA)$XLr>!?u8<2Et-k zzzJosF5m*O7#Rw6^7OzuHL&Q!rarBNdjU@k>q5p_))MXoyfw@V8EctKxU=}O__G8+ z?dvST6sBILTIL$o8s-|7G*H`y+3yxBI3H*--r~+H&4t8NYO$u!Espfm5=hc22W3}* zlKl9T)RNSq+|0bp;*!kdB2Y5C#hO=|TTpq6EwQ+yvLLm{8>E!AG&8Tn;ubG-gdsi? zG%!(A1xj(LVEK~#qU4NQti>gX$=SC!6N}?Zz~dFSnDX*&aix{!C4)u`5_4{`=jWxy zXXKZF6T2n{I3L_%&CE+lt+>USlvn~X1e~V971k}z{L+$mh%jeyYGFJm$E6g3+H{bF z2QTGsamUA}r<P=vq~^xQPXwhxP^*Z6Nq~`sk&B6onS)V)iHnhokq^XT<YMGu<Kp0A z6k)28!IDzo24U1gph6a$SRnNfC}}h^1~U|Sf$P6p%sHufMWA|WB_r5p=#FHMkIzZX zi;v#{av;bo2F5B0EDiuGN(NPWpsWPKYzzzx><kPHV9Ra5ZQ2^fX2vYWT4r!MU@BuM zvZ!ImVgmP@SW+0l>6Uo`OATWUOAT`kYYNjG7Emt-+?r<gd-?zW|Nkplpw$?tOI|b= z6q}$5{}v~>E{o4hsnBEv*DFPdAh}Qw0rFiDD9$z6AYKF)VvV3gz)@bD84qekgN?ey zoS9OA763(Q3=9mrL7oSN9|Jhoaxt+mRw-Zx0&4Yyl5IH|7#KhS0t%pF8%73(bcR}{ zbcR~y5{3?jW=2rw)P*6|rk16IX#sNy%L3LKmW7N985gjXFoClvYYE2!PDr+@VOqci zVlgt*FfZV$VaehK1yZ3%4O3B73C9AS8m5JewM-?vHB4E2&5S9G*$hRCYM2)Ar!axk z<!O`%q%e0dG&3=Rx{`tmxH=d>;w2mlglZVFgi}~R-M)n^wM-=<pe}Hha5Ga3Q!Nvy zC|Mv{!_>jBkg1lvL<~H{1L_iTl!(_bWJ!QJm7o#IS`JWWIZL93qlR4^G{VSX!vO9s zOExpsFl0$(vrJ$tvM7<RVaSqfX6$6_U|b*raREqVfoun33X>$*MeU4fjNm~d4v_gS zpe`yyyhI9jFoPyfRTXID5j?yD8<7Dg0EP6-vQ)^R0jQk<9wh(`Qh-tmc<7?EASJN` zG^PO^V^9F+W9UEwTsgSu18QR^6qgj0CWA&T5u+JJpaczWxmF3ff*Lpqptb>cILS(( zND9<Y1;sK-Mg$3furVkhfwFlGL#$K`Q!Qk~kwKE7gCU!#h^dA#jDeA%kO@5W%2c$B zfq}sfoE|lqZ!zf^++xhU#aO1vSOiYMh=PRECOtJTwW6TN?inb(gBp!hdYH)*UQHwQ z?qC{=mZNvNKzSXMeT6|zwqgLK*cwm;!dSzY0(Ni;qa?`DMi71tINoX)Q<x<gKq+Vr zOCD1V<3y%HmS6@=Hn3N~VFu1<;F7Xv3COWaK?KMvnoLEjL0OV1B~z2}7Ee)XZa%2F zlbu>w3=RTt%lH<T4X9lRY9iQu1nC6(Ri8M&!gcy-%0bEq(8wjEwqQ>z1~<oxK$)~C z4rKHs5HT4<Oa-}%E3GI$Hy&KkfD02y0W%FGzYrwPTAYzska~-=G_Sa@G&Qv<^%iSE zQD$CAQ7l-tC^0W3uL#^U5-iD&Ps^-GO#!!(<C7B8Z}Ebgwk7$HQL<Ysd8rj8w>VSt zN^?_-5=&CS1$0pe$S!aT3)CgL#Rd+d)LYC2`30Jsh?4CVC#X<^)IeM%py5kU0aqNw z31)+G0=PE1#RU?E)Ko<)L2l##Wf@Qr3Jz~gCP*^d0J0U-TDrv*4^KfyK#4*RRLZlm zFtV@+Fi9|SF>^8TF@c&jeBcHW52Fw(s39c7D8wuTuE+$KtEBP76iVd*$}AuZ%4@}- z)CVq$7|IxnBx)GbK_v-84U-K+9%xvNp@vBk)W2c`mt#yd3|U|iP}wH|8i--4n&_C4 zQjDdv01ajqrIwTy<rOPH20~L(EA$jXGE$3D6*7wzK*PA;as^&8fYOjcT7Hp2YFTPg zr2;s_KtZa5JQNHXo&^;lpn-iz(E=Ji&CE$fDnyJx)h{z7PC<235hxL9GDD&f<a0=t z0p$#Ebbyn8(N5^tO)&>3>{+-N`IrP4tK=|46;u_$k`qcr1Tq?wwZQ=kng9R?7o#Kt zC_97ZKY|%Fnf$;RbR{D=-N78S29yv$l`X_QV3&c@1-Oj{N=(d*3=GAf#wi06A0w6q zEvPLAQitLokYhlO0^wqigTOTlXex)Jh7nZigC}87N>h*)kTu|<2~^;MMrLD$Y8g8~ zQ4jJ9gCqmA=wXBuIEO*O4Jl!Yz#dry@(tV%MH?9y7`PZ27>aNdC?KbU8+pY@g(oO6 zfQuq<h-fl_BkvY_W?o8Waw@dc0^0))CN6{-9H4*#`GSE7OO`;WLkTEQ$pi`{aLgru zTmD%LH6Y(IN-}_A4%D<OV=PjtVOYQfY6O5Qc!&xnNrqZxcua%FsafDW7D;H20hEnE zZLu1b8dh+-oVDtXGbk;WfaVckEn<blyh_kGqe5kVsX}>TUJ0lIPE1ZtEiP6_R7lP! z&C3Q2?1R$3LRwLNE@%WFo_#=dF-QhHCldlP7CdW^o0y%dP+VFBs+bc?Ksq3qHM1Bz zp^yrm2}n*XPF2V(QOGPtF)lA3JZVz`vJ^fS0nW6D+y=5XQ^6Ke*~4{1au}%MgQO8q z+SX(Sr+G-e1a+@qZAMU{xy6)L0Inzz>cKG&AwZE_)X%`czzd39P;Z%moe$JrWC8Ua z`IvBI2vBfnGP@OtgA$|!D4FsVff9BRXtb*cG?j6SJw84qKRG@g++QtP08+RZM1Y#6 zMIc9l%Y`CPp<HwfBmr{8EnaAkH#rqFdkAqOIKg*=lz?3;#=yV;8U`r_*~P)g!p)b; zq0Yg~!O5Y|A;7`P!CC}Tfpz8&)N3yWg&2JHl@nC$w=hOALuLx1SW-BG88o?y>KPaq zRx)M!fCk#OK-i2R_J5EwKx5j;pwS&rc?QBDHYjI+Q#okL6FeKk04o0D8A1JqcqZ_i z3^S-#0Pe)BWGV6k8RZWm0ziZ($QD*`Zp0>q6cpf40F45ufZ_~Xs6u=FzZ7u|jlc)V zQF13I0|Nud?O=Cb0Od|lBOKHZXIuzsXftIq6rCvn^=_C-n6p?HGSxDbFfU*O<<l5u zP`|yFwFIP^9W+k_DvQB1lO#h5lQ^i!%~k^HZL=+8tOd<dv6+Lr4q1!~xM~=(xS{H~ z!A)mS^PG7BPYp{A`$DE#7O<%xRlHDDyiipfB_LI-H5?0>YFSH|7x01l<Di@m?moa= z$`5jB4O1|~N@hPu0}fPktrUPXo7h0jn*7`%P>A1RPA)UL#gSZA;+&t8Uv!Jt+26-M zC_c#D$;cx(>=tWrK~84LEq0I&P_qaeXQ1ZQE%ws9Owh<!5hw<4F(;;^6oI1g7He*5 zQF<!O9FO2Iqab&umC%+BKPaL>z3N+x8MhdVi`If#Fr4sl(p#MH5z<?npmrp9^b|ej zLDh{mDCR*4je&`UNq~`$QGij7k&BUoQHqfVj6q^}Qv+mr-cOUI=oH9NXW#@Vaj#@W z%gvBX46+%?HQ;;+PK==3q6BghsDk5Q;$jEqpO>JO3Yv`I_9IHN10^|7#}u6GGQgc| z&`5I@6DYD5GNv%pFt&mUf2JC6n+Vk3T)<qzT*3hEW-~+7NifuaSWJ=-ni)#7NHQ#B zNdxsJS!!5HSZf$d*lHM>8Jn5Xn1UHJS&@>%OHht2S_LY+*o(leWe!j@7o-;D7J;(- zE#~CJl3T1*sYUt4x0rJ?^B}Dq)?3U;iRoZZK}L-;Q!0vYaf8Q?!OJR&Z*hb4#1|x{ zq@<=Gg%>!eKtmF?pr8VUGXt{(BL@o`qt<^m<|2>`6-q6n_yk!&VL4TF4wNv?g9vbR zg9)%lKn0K=$Ro%FP!XsQL5b^OU&It$1O?6|5OEhofI|pOfP)E?XF@^2L@2iwftJGI z$d5%&P)$Y{56XEyAmd><4>Wa>o0ypwA72D2^L{bvSLr~O^uX8e#1|ChXQd{W=qKl= zr0S=nCY6GR*E93<l8ei#<RGd+HM~Ati@uSmS-g>{S!r&SB}5(}1s(*=%uA0iNleN~ zE!GDYnehRR!SN-@xkd41nMH{?dY}oEDl>GgAWMo<OZ1a7Qj@b0cEp2?(1#9!7Keal zpi?U{i%a73vx~r^1-%Rm3~oi%puEdkgc#We4Ud6GJi*x$HZTiX%amGF{1Q~`XfoYm zhpc`oD*6P{&RFymL^<7J$uKRbECLN_LTU+iP!B05GYQh^29*y{+~Cnyz0#7*oZ=!- ztM-;Ol87G2H(&zP<4i3o0!0aU5-3UxtgI}vI5R&_4>1=~Q~@ggA*MnkqlCdK_44xc zU`^M=oT3?^*q;ef%jkBCIk~v(7IT2752V+B1tiQ?keHW}SbU4AAf+fCB*2%LlT@0U znpXmzxG65m1PO|v3hEV9f-Nfo4Vc{GfjcX;sJIA}r*E-k=B9(@<Uv^&GV#QenRtr_ ztO*u|pkn+MTYeHKJ%LwT-C}o1EJ-X*EdjT}!1cf_W>D1$D&D}g$}J{Pd750Fa*MSf zv8bf@7E4BcZfX=KI32}<d<U8HE6UF=0mtnv?)=h{g3=Pmq#XyS-JOzIR1A*jTkN1v zFHSAF#gvj$v<(zcpg!p>#^fk&aMve3IlnZo<Q5BP#JcD;NFFi%%9dD^o|}j~cnZ&0 zw>WHa!RycLK(oEYpfMK?CJsg(Mh+&>GzkwA4<iSo5NKKjHi-hBJ^_jHFbl8>)icyH sun0&C@C(RtfpvkVVtANAdij`Gm_%4Xd=XC2$O8`}2QwFw7#j~G06clsv;Y7A literal 0 HcmV?d00001 diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc index 9550db1509e8d477a1c15534a76c6f87976fbec4..59584d31b16ee5ba7d08220a2174fb5e79e9aaed 100644 GIT binary patch literal 10339 zcmYe~<>g{vU|_gBWkb>#T?U57APzESWnf@%U|?V<wqj&pNMT4}%wdRv(2P-xDGVu0 zIZV0CQOt}GF_tLS6owS$9JXBcD0Z+IOAbdaXA~!x&6>lN%N@m?%M--|=CkGS=JG}H z<?=`IgZb<^0=a@wf?zgBj!>>}lrSSh3TFyej!3R(lqi_Zog<bj9wiQD^W;e6N=8Y7 z)$`^^<w{3MgV}sJGP$x*vS2oUj$E#MlsuR%kfV^R7^Rr26s44_9Hk5v6U<S`RgF?* zWN>Fl5lRtmVMq~9WzAA+W{y&KXGjrA5p7{e5rv9qxHF`PrHHpMq=-XBG~F3eBvK?> z7*ZrbBB_eaEKyqS3@K76(k%=r(y3a_%u(7YjKK_=GA}`)>Zi$gi^IR5Br`uRF-Mc} z7K=}Qet{<AEq1re5-^kL7Ee)PUP^v$d~rceX2~rsm&B4ppZvs>)FMr$TkIfSacape zw&2pF;?xpN##^kuexAW0nvA!&N{SNmiqrCoa*JOwGB7Y`GTst!OHC{(ElPDtOotl5 zcuUa3#m&^$(+6Zy4#Zr>TLNAI?&xAH0f{9UnvAzZ97{{`p$55DloTZ<m*f{|GTxH# z%uPznNz6-5O#zuzoLUl`m|KvOTBON*i`mi1`Ic~EQZZP0a!z7#ac*i!Mt;gIKA0ec z5AuMwTYgSTGAMMAF(}=E_{9+n3=F9ZptKgnlp@;Bkj9w8n!?t?8O5B!p2E?>5XF)r z21+9>j8Uv9+$lUQ3{h+;yeWJw3{mVU{3!x03{e~@f+<2R3{jjZ!YLvx3{hO|3@i*$ z+`$Z*;<tD~j`z$fNi9lCOiq0XO6o5;7#J9exEUB2{F2!~Zh&HtFf#)K14yPgjDdln zgrS*X0pmgjMur;38ishL8m1bCc;*`B8ish58kQP{c-9)$8ishb8nzmSc=j5G1sou? zC7cVm7BXaUr!Yz~)Uu~A)v(WINMW7}(#4a)63n2<TJ<BkxHLC6v8XbZi%S6t9E;M6 zt>8S6U{Yd9az=b{W>u<!MrK}#jzT^xVOc3er=%7q7iEG4V)YbUQqvMkb4rR8O7a!V zO!QEU$t%r`FV8Q^PAw`X$s7w*V^R`J5{pwy;`0)7Qx!CdONxkfkGeX-AaJPvV$`p4 z*RRY;%+n8Yc8)J7%FjwoF3~SaPL2mBrOdR<<U~*w(udkvtnZwdo0?OZpJ!->#MDbJ zF7wmmy~SEwT98_Fiz7ZhGcU6wKK_<KNl{{EUOXcFZV5s}kb>eCe@SX_39=diumG|W z9;p3bKi^`@E>A2<FTN!J%2x3psrcOdl+v8kB7O!2hFi?Z<tevV3lfV;if?g(WsCCj zOEfudu@tA~q!r14(wr=akOL9&AVL8|D1tPzWaQ_ju4KH$84n49`1q9!znt_zqL6?B zMPPA3Vsff}2&fqF_4LtoNlh(qFG@@+K?D=n+Duqt*DI(jl4f9FkOCF&BA`Uh!NtVJ zh=3f7e2i6MC>cZ#7Kq89G6tjvgh6Z&b_SWm!T_o$f*Dpa`W0z0FfgoSEK&n`1jGgl z6lsIJ#8Xg|T98<j3Jxg)kVYnuQ7lzbs1_nCO9m?eDPv$@0F@9RyTF!KFfcIGFs3kO zGZhQeFs3k-G8BcBFw`()F{Uu5u=Fz3GL<luFxN0OGp4YLFf=pPGQ;>a%n~3zO9@L2 zQw?(sOFC0CQw%60L$WSo3R^Z)aSGJT5~!K%ATvRkI)y2iL6f8El!5|0Pw6To=a=S{ zDCCvqCZ!fB<fkcsQ*CN7A}{MIB$a07q$rf-CnqMA<|GzXBFhz|7J-U!1#p_wFH6kP z2c<j(aD{+kNMb=jP9;bgwA@$FNJ&l0%u7vCNUbQy$xO~H$*EMx%u}e$FD+64#XTfF zD&!X_Bo-HErh`;w<|%-3g(lqjx0us1OZ-AKnTk|F(ZmXN#q(J+LExqREf$bl+%!3g zctCLqR>GW_S8|IvCpE7K6eYKK!8*Xv3aYBO%Mx?o93GHv7*CTGoG@;&7MCQZr`}=* zxe%m-G4U2-5+tR7(+oJduokD5loptQ(uo)-0q_VgvN1CKXJZmz6krx(7hnSMs)SL~ z4p<LHN(Gha;6wv52oj++3|UMy3|Y)6j44dLEWr#bnf=@}nQw6+6%pLXIiI~aqcklo zCsmWB2xMB3DJU=ySy&gu1-l(gfIR>$bU>a26%eUK;UK?&6EtI$9ICq!AqY`Qju(n- zL5%>Y`#|A_RK#F&A0I59AXY{pxlRVvb*Ks{cAGgUun>hFHn%~_VnoozA~{YL)p76y z2Jr@p>p+1GiY{>Ehk%MH#sv&1OexGQ;F4kiQwqyMMi+)=#)XWvj8$?a%q1)fSW{R* zrO0fCxl9WgYZ#kBB?DVHL!JmD0|=I|r?53MHZevr<S`jAq%Z_CXtMj=V!*H#R8R$j zTn#E`(-~?QVtHyAYZw<WEM%C-RLBy{08a0qW?XV&L1sx}PG(iACetk@J%d|}nP52x z0ZvL>HaVHaCCT}@1$Idc3=E$^`Ju`PH5@?R2WK>BC8TGQlb@WJQ*5V)(0q$Az6x9w z6{nUI>s1+e=9MMpWTq&9Oe#(-QE<!5NmcLz)#v#|3NT4$P-_k>stBsQvhp+Yia-_B zEmm*=b&EYcwIs2mr05oBW^O@#QAti_(k)&H6Y7*(+@P3*ak#<~b4pWPi;D7#G?~B! zLKIgrs5%GNst{kpb5W5J0|P@MC>%j~mw{1+k%^Iwk?TJbBh!CS-36l2GZT79pu`O* zIe_BE8)RkzxJ|c!5mXvvF)w7yVgZ#0txS>(wV=|0bphK#hFa!KhFX>yCKrZSsajC` zj=7AXNUMe+i?xPTlA(sBhSi3lP^*Sz0ecPOLPkbV!eUBe$YiKtS;E){>Ig7bomBA5 zD=taQOU?vGvqGXmZhlH?jzVd1W?s5NNk*zda(+=!YH>k+UJ6L0Ah9ShH?<_Ss2E;I zgPMP7Mftf3ptgfv23SS0LP};bs9BU)RH>&BnO~}qoS3JOm{XjukeOFdTB4AcSE&GM zp`e&jlCO|omReL^lv$FB#qKgt<bdiButks-4q9`~Pm|*oduCp7L27czEiO>(7o~$Z zMFF6U4QgQAVgXrO1Zl}+r&bo<Vk<33Ni0d#WCKU|Ew22MjMSodkg6yyaBT@;vt{Oi zlt;0GWDIXHCl;lnC5T(RnaE+81&T%iP<~*QVB}(CVdP-sU=(2FVyuz|C2Dv(Uk^o> zCTo#6$gL6}0@S((H$(k28T~YQK+Q%_lQus77FRqp&Bn*y;)#zhEKSUT$bef>MS39J zHXy<gL^y#6P)P$$vqh1h%9Iz{d`nKv$%&6g@`?sXF(_{rfzma|9mODDa4_+Saj<bI zaxil+bFgr5f@LwfO6(y2f;?V;+*J}zWlm*IWl3d8Wldq7!wT*=v8AxL@J6wvh@`To zGD|X~aHMd~VFJr=rEs_KMscKartr*R1b3fUqPU<uP$!BdiW|xUb){IMc)(pK-e3kz z(ObOGuA?ucwF?TsmqrW>3@<_DZOBSiNJD}vCow5CC%z!DB%{a<<XwAEh_EInCTFB3 z!@L0U2`2*s14t*R)m#iJ-w1cJQb64=HgIPP)S+UC=STr{r#MqUoiHv?2_Da#0_ueE zq_Ed8#Pg<b)G);Jr7+Ym#Pfs7?3J8;nGK8#3@<@;zXT146oJZumnT5(j0X`QPlad- z6oJ~xkmv?AeTqQNL-Gw*W^Q77D!5B_i!%?>R8K9w#aohCln&||fm?*P_;XVeLA{Uk z_?*<d^pcEQd}WD6@kxmYKB%dYl$c%|4{iwD;zjUNi%UQqhSZcIP%^m1la>ln18M(h za)NUcIJnVVR+I_y1V~E}sNMr-#UgN`Llmu`oLU7+MxeTwfsqTTBY@}<urabRu`yPO zp+q^nQ{cCfsmKBpt!zc9C8b4qDD8DnQ45M^P#3y*611pY$dJVZ?g%r1TG(X_MM)*h zH4IrSpaK;XhgobXEWIqX%q8q495u}F);kM~U&8`!y|b2Zg4*sitnk)5R}DivcMW3- zTNy)93zC`aH4Is7DU2x`y`c6zPYG``Ll$2RV+v<4sBq@;184n}OnycBpb!J~OI9+2 z1M?OOsFa2Tq#>wV!{k>4O66celLf4%s1&3L6xLDv;5rVRy20sTB_lXwH2I4_>AFZ6 zq)-J!@PP<KVghL{G6sp6fCzAl8Z@+%nFp@TZ*hUEXb77V-q?YpOHkpn5>ljsEh|z7 z8DR+`tU!b{h(L>Vke)tJtP6qi9wQ$U2csCH0HXk-7!wDh6jPNjYD|M$)GHZ_iWnGB z8#$ovA1H=FC0;S8C5&iTgF~OO8C0#Zq%eY-*P2Xzn#|yQbBnzgGHO%=Dh@PRA<bz} zZw@S7BnWacNU8`+yRZmU(9Q!n9#md1FjmQ-Iu@LJz(ECS1EdzA)cha^gF+YVU{I?J z;b2HlpqDwAVI`BFCKJRZAV<KO`;{P9!3#00P6HMH3z3{AgX%N_ZD@iHE2;w}L1^n9 zlxq==#OfT7FP9)WM;6sN@Pq^jC`}e{HYfrYeV~#sM3d2PC0`M!Q&1EQ@_Y;^mopca z6oJ#~EtaJG{2WNR2#Sg#P^yANB`AuDK!H&NazYWP<&L%9i(_D507ZK-sPyGvk`a*- z;A7@vE&_>z$}Ldpgkw-K0^$_otnt{OH6Cb)5V^)<rAmzl9+gF^?utM)Hn_?N@hbu~ zPm@7og&-G#Fvv%sPyuE9;tX)XM`+xT)C#QzRITxWt2KUbwI%?n){0U<T|hxd;Q%V$ zu~rQr7eR_OQ0u8E8x*}cAR-q;<bjBMP;7G~B_?Ml<>#e>tDYiI(68horp9Rj8G*Oj zSp`a&piTt-YKM&x(!qy>JiOK^Dg+fqpb!RyHVA_XB2XFuS2Z)hAr2a+MXPGkzy%Xh zRa3)|1sd34tz|CZ097bR)eV#nuWq=|svGVahIk%mb<>4pF1WgZnG0%ql<+k(WbxO4 zDj-mW!wIQyii$yD1gdQ?YLX&w@PTTaqB4*oaGg^QidX^ks)QFLk5-j{b)(iNMWB3E z1TMqCB?q`o1J1G4AOo;e5c@#!gQy^+;1z@nbCm$7VGWHDaP3f31Bwn%@d}Cx5C*Z~ zB^t;WL{SUsV1Wu+P0)Y~V-aZhU6UD7a)4b4E@W{P{bxbW0>?YFeGhd8!Ezte;{z2( z;6^sGV|tk&Wj-UgIEA?c98@@60xH}tAh|>mDL@Dm^LX7+)CP)VXkdUN3a1l5ez}a~ z1n>|lEHFUB*r3J;v}k9A6zxbQmL{WLQ3ELE!EHe>0Zs$p{zx+e1H%fC(V$WcTrTs2 z%Vm%_M!5{~4=4(YL2Vb<q)jSwD!3m#hY2(k+rrY!7{v+|W1Yj40_u~1#Mr>oWE{7I zof3;v!E-{0aP(Wrbc++z2u{n-NhyXjKEUw~N?p+I0w|Gz!l4-C2yp4A1M|~L$hbmL zCrkiR1A+W^iyNF~<G}{s0QnPaE>o2}7H8-osa?r*iwmSFJ~J;RwSr_PL5Dhuda*c3 z0AUTp^gBrIQo-*ohz3n|w5qxY)W88()kWPP8$m5TNFact8%%(Lqo|#Mf#EvHI#311 z!N|lY#LUD97R88cPyq`H|Kc9xZl@$tw-db*mV$OW!If|dsLRP3#gWRA!kWS}hY8y0 z<OK8i=P;!RqzJZvI-%StLUS0wGj}XeJgE?#Xo^@1OB64ZC!QkF!V<*??vnBcGiXZR z5_2xfFD?#C%!x;7=U^m#0+lr=S5UwAiLJ7(VTcz@0rf|Pz@#vk6agn=a3wBEok|>1 znSi1aTz^Aq0#Mm<3);s3CFxrnpkfQeDryBK7I0yJNFtzy#3N7=0f`f;u-Vukea9+! z>@k9;5(njDP(=>T)NUA+I8r4G8k}QGVFvXR!IQ7-NOGX+9V}M^8WjMI&$2?RTu^OW zR08rKs7lr3M6XTxKvHP6DY#Aog_kA^q~!zZ?ILxvz@Y>tz_|%jsr~~66(~0`F!C{r zv9K{$$>0bIaJDb%0M(x0LI{+(poJSe96*Ci;K^nNP|exGPy!m0W?aAo8mk8NnVT6w zqt~D@Y|z9rDE(+M`=RuEKt_YwMc~nB&>&MfXf&DwG#ZUG5)CdKG?|J(;a=ne3JI{U zzy#Q99t;c&{EQ3?MV<@{3{__MyaXv+u?;|LvO~&bQIG{7TW+xyCFZ547J=#yaN@kh z1|C%|F6sg)0b36yz-|B)WxR|G4B#T4fsu`|N{O&Lia^B`%D@CDp@BjKoY24l%#g)U z!zjrR!&J*u%UlBLxH8o+frtKKJZ3PD1r)$Vpn67=wP-TPg`kKAJG^KDh&vHfg+r4H zsHlfcj%DVhWG1H;M}a5Dpd4_JLkO_%KxMQb$akQ~XJF(5O~$Gd^dF=Gz}^M}83gtu zG5vF<qFRs-z};SGw;b#xa3h3pw_F_La!^`hV64LBYJ%-Fc%u<ieIeRui2gaa8w~4} zg901eZUc`cVC@Ej+G^5BE>p!3o}dB{(Pl!)KcIF7aqcRr2L&Lu#w#ci<dB@9hQk@? z-EypLH%$Tb%Dxd~FSyO$1Y&^->>^N9K)U6iTvG%}T##lqu^se&1_p))pt1~9E^;u* z$cXU^un4g6vG6e$f#fuGih@A4`GNY1Y$cTinR)48TW<*#RF>oyC1=FvWTt17<Ynfi z-(o7riITw(*2~GyFDTYaN-Rz-0%gc31zbwfGE3q?ii<#bH%b6(yk2QZW=^plDB8di z(O~aINubMt6%?nI6oJ|=MWFl|B?4AcmRX#cp9hmJE&`1OM{y%b6@i+-Q38<3DUdag z7BzS#JW4RRw74Wc7iqMq2-F>m5`~Iq6{I7}fdcUsJ7_RGCo>6L^F;B$7jUH(6&Ha@ z<|rPp?Zwa$GEk<w#hjRwTm<Sv6$yhv8<a6_i5F+*q$U>S>48QlbK;9rb5g+*K}Dc) zHA>hA7VL-yHFzQcJhoB<YRH2dc;Hxslt17i7EFLjqFWp`kc9(wptN5MD(*njDm;uF dj694yOrXI+9uNzGnM4?QScDRU421MWI01I;6;=QM delta 2455 zcmaDH@YqZ+k(ZZ?fq{YHhHOVtyaEHm<H>e1l1zsfCh8c}=W;}GFfybtq_E_0=5j@G zf!VA%+_^kaJd6w}Y$@zHyt#Z)d|)<54u7sdlmJ)_XO3X5P?Qju&6OjZD-tCFW^?C= z=88p$F*3L_r0}HhwlJjdrm|*<H#0{`xHF{irSP{fr0_#UB;7R`QUp>2TNqLVp+ZtA zjKK_=LN7t?_0wd$#o=F2l9`{Em@~PbQFiivM)%25OqPtv69q*k&t<A*WZx{!T*}DE zK6xU`6-KVfX{^dp!VC-yMIs<V6hw%D2yqZ0F?lg-EUOR$1B1Y1R<<x*O-8>WMFs|j zm5fEQAbAiQEKsDxz`$^er=TdcAh9SlJ|(dvQEhVv+Z0AdnaQ#ovZ5Ry)iNM8%$a#5 zx0rKM^NN@!`*RpD%1&<JkP<^Ph_yJiq_jY1@&=AcRxBW`+8_eqOl1%kO^-lHQDSBu z*y^18#FW$`-^n$c@%F|H3=EnqMPO5OK-%z|!dI4<gJed)WGSv#2V^sJL1rL=37bp! zOHzwVkPHc)yq)VZqvGUlZfR*WM|0WaWEPhs=jRsKMNB@(-O4CF*`3Ez6%@EpT*-;K zsX6g^iMgqeFb3-d1$2=l0|P_&=B+$kpm=oPlhJYmxe63dw^$O3(u<2EKtautomyFZ zi><UEC9xz`lMNhYQIp&FbmYJ$feDatw|F!2ic1pnk~0%?GOJP(CSTyQ5a0uu$PXge z!Ng<<e!VP^4Mi#-z7dEp2N4z^!V*M)Tvg-;vY9tNJ~=0`xHvgACnr80$@TIebs#&7 zco`TNKz0-hFfcH1F!70TuyHAJFmo_-uyAk|@lM{rZ#MacSe!GjucwboVo9QJeoARh zY7y8mAZNW~Vqjnh@mtAS1Ws05If+TBIq?OFB^gB~AWKX^?qE$$OwLH194_AGv69m- zvw@L;;Uy;n1H(&D<}L~Wu}**pux~&{hpZGV5(H@n`xInyQ5c8|H!^CnpP*QM6mLml zQF>}gd|qj8Qfg5Ye{O1GUVKtwdVEf5UV2GJ6kl0l5eh%AG#A7#jt8gGC|(3VwYcOK zOL1yy$}Nth#N_Ox{Jhkna0Uj3D4w)bkm|hRwEUvnm0aNH1ADf}5!FpaaUdsx^b{q4 zSY99klmd#t!3~e+B9NyF7#J9e6(-v%D6>|H`Fi?H-Y6l=vXZGtZ*s7Z>g2}~bLv-e z`hkO4lgY106=akeBnyEoy~UE7pOOl(5Nww^NRG*`NCTu0L})S=>48cJfg*4iOM%3t zK?Eq=A<hRGgk*#!NLC9(FoWW3@={^taBv)eVr3;H*@De1k^?Ey2N4D!!VpBDc@d<i z1>{ALUl|zrm^c`v7zG#w7{!=47-c5Qi%2piPqq}1oqT{(%0ZLaFJvWC5h$q?nSmS) zEho}J2E(Jg2umI>0ww)PlMjm|S;Ez=gzC;9tQ(YPrcRC%jW)?cvNn^jMo>PRG5M{i zRy`{?j*3880i2LS{4^Q;R`M13gA54(5rLr4WG*f#DguS@EtaJG{2Xu+C<0}jB2Y*{ zibarLia<^(0vTTv1kwynG++W8z~C$!%)r0^a&Ivg0|PGu2a}A56dxZmA9GPKG`D66 zf$|Wtn~FjwGs??OE)>tzhyWQI3nCIh1lU-x1z@v5(KQbgZ{VakdA)=fOA-SE!{j(2 zDNv%cD@p-rN&^vDAOcN0$TOQj+Ck=EL_5=DM=4>(Y*@T&Xfpe`73F}8$psO(;`hMh zeNq;a=cy<ugM3|-k4r5m=nhZzkTyf81?4iV=>`<S$0nba)(k=qDp1{82(khkC}0BY zNpSF#FfcI80r?M<sgOb*B(5o4<N?y>GC5dHLqeb=zbH9Fue2mHr&teE4;O)QRT0Z% zMs+D+R#3UZ4OUT>S)7@lSHw11PhHFxR6R!tCYKhM<mVz)Tt%Qfdy5@ZsODrQfirXz zPf1ZCq)19FDlXy%8OQ@RusAWdASbn`h-dPBb;&?b@m>T<Fhw9IBGmw3kAeJji^C=t fT&CMGf;<K)YdIJ>7<m{u7zG%4kRg*$f{-2nz3A|R diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py index 02850f5..cf65534 100644 --- a/datasets/custom_dataloader.py +++ b/datasets/custom_dataloader.py @@ -41,7 +41,7 @@ class HDF5MILDataloader(data.Dataset): data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). """ - def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024): + def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024): super().__init__() self.data_info = [] @@ -55,7 +55,6 @@ class HDF5MILDataloader(data.Dataset): self.label_path = label_path self.n_classes = n_classes self.bag_size = bag_size - self.backbone = backbone # self.label_file = label_path recursive = True @@ -134,10 +133,6 @@ class HDF5MILDataloader(data.Dataset): RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)), transforms.ToTensor() ]) - if self.backbone == 'dino': - self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16') - - # self._add_data_infos(load_data) def __getitem__(self, index): # get data @@ -150,9 +145,6 @@ class HDF5MILDataloader(data.Dataset): # print(img.shape) for img in batch: # expects numpy img = img.numpy().astype(np.uint8) - if self.backbone == 'dino': - img = self.feature_extractor(images=img, return_tensors='pt') - # img = self.resize_transforms(img) img = seq_img_d.augment_image(img) img = self.val_transforms(img) out_batch.append(img) @@ -160,23 +152,24 @@ class HDF5MILDataloader(data.Dataset): else: for img in batch: img = img.numpy().astype(np.uint8) - if self.backbone == 'dino': - img = self.feature_extractor(images=img, return_tensors='pt') - img = self.resize_transforms(img) - img = self.val_transforms(img) out_batch.append(img) - if len(out_batch) == 0: - # print(name) - out_batch = torch.randn(self.bag_size,3,256,256) - else: out_batch = torch.stack(out_batch) + # if len(out_batch) == 0: + # # print(name) + # out_batch = torch.randn(self.bag_size,3,256,256) + # else: + out_batch = torch.stack(out_batch) # print(out_batch.shape) # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch - + # print(out_batch.shape) + if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]): + print(name) + print(out_batch.shape) + out_batch = torch.permute(out_batch, (0, 2,1,3)) label = torch.as_tensor(label) label = torch.nn.functional.one_hot(label, num_classes=self.n_classes) - return out_batch, label, name + return out_batch, label, name #, name_batch def __len__(self): return len(self.data_info) @@ -184,7 +177,9 @@ class HDF5MILDataloader(data.Dataset): def _add_data_infos(self, file_path, load_data): wsi_name = Path(file_path).stem if wsi_name in self.slideLabelDict: + # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset label = self.slideLabelDict[wsi_name] + # print(wsi_name) idx = -1 self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx}) @@ -195,14 +190,29 @@ class HDF5MILDataloader(data.Dataset): """ with h5py.File(file_path, 'r') as h5_file: wsi_batch = [] + tile_names = [] for tile in h5_file.keys(): + img = h5_file[tile][:] img = img.astype(np.uint8) img = torch.from_numpy(img) # img = self.resize_transforms(img) + wsi_batch.append(img) - wsi_batch = torch.stack(wsi_batch) - wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size) + tile_names.append(tile) + + # print('Empty Container: ', file_path) #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5 + + if wsi_batch: + wsi_batch = torch.stack(wsi_batch) + else: + print('Empty Container: ', file_path) + wsi_batch = torch.randn(self.bag_size,3,256,256) + + if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]): + print(file_path) + print(wsi_batch.shape) + wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size) idx = self._add_to_cache(wsi_batch, file_path) file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) self.data_info[file_idx + idx]['cache_idx'] = idx @@ -461,16 +471,19 @@ class RandomHueSaturationValue(object): img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) return img #, lbl -def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512): +def to_fixed_size_bag(bag, bag_size: int = 512): # get up to bag_size elements bag_idxs = torch.randperm(bag.shape[0])[:bag_size] bag_samples = bag[bag_idxs] + # bag_sample_names = [bag_names[i] for i in bag_idxs] # zero-pad if we don't have enough samples zero_padded = torch.cat((bag_samples, torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3]))) + # zero_padded_names = bag_sample_names + ['']*(bag_size - len(bag_sample_names)) return zero_padded, min(bag_size, len(bag)) + # return zero_padded, zero_padded_names, min(bag_size, len(bag)) class RandomHueSaturationValue(object): @@ -510,11 +523,11 @@ if __name__ == '__main__': train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' - label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_no_other.json' output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments' os.makedirs(output_path, exist_ok=True) - n_classes = 2 + n_classes = 5 dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20) # print(dataset.dataset) @@ -528,15 +541,19 @@ if __name__ == '__main__': # print(len(dataset)) # # x = 0 + #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5 c = 0 label_count = [0] *n_classes for item in dl: - # if c >=10: - # break + if c >=10: + break bag, label, name = item - label_count[np.argmax(label)] += 1 - print(label_count) - print(len(train_ds)) + print(name) + # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG': + + # print(bag) + # print(label) + c += 1 # # # print(bag.shape) # # if bag.shape[1] == 1: # # print(name) @@ -578,7 +595,7 @@ if __name__ == '__main__': # o_img = Image.fromarray(o_img) # o_img = o_img.convert('RGB') # o_img.save(f'{output_path}/{i}_original.png') - # c += 1 + # break # else: break # print(data.shape) diff --git a/datasets/custom_jpg_dataloader.py b/datasets/custom_jpg_dataloader.py new file mode 100644 index 0000000..722b275 --- /dev/null +++ b/datasets/custom_jpg_dataloader.py @@ -0,0 +1,459 @@ +''' +ToDo: remove bag_size +''' + + +import numpy as np +from pathlib import Path +import torch +from torch.utils import data +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm +import torchvision.transforms as transforms +from PIL import Image +import cv2 +import json +import albumentations as A +from imgaug import augmenters as iaa +import imgaug as ia +from torchsampler import ImbalancedDatasetSampler + + +class RangeNormalization(object): + def __call__(self, sample): + img = sample + return (img / 255.0 - 0.5) / 0.5 + +class JPGMILDataloader(data.Dataset): + """Represents an abstract HDF5 dataset. For single H5 container! + + Input params: + file_path: Path to the folder containing the dataset (one or multiple HDF5 files). + mode: 'train' or 'test' + load_data: If True, loads all the data immediately into RAM. Use this if + the dataset is fits into memory. Otherwise, leave this at false and + the data will load lazily. + data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). + + """ + def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024): + super().__init__() + + self.data_info = [] + self.data_cache = {} + self.slideLabelDict = {} + self.files = [] + self.data_cache_size = data_cache_size + self.mode = mode + self.file_path = file_path + # self.csv_path = csv_path + self.label_path = label_path + self.n_classes = n_classes + self.bag_size = bag_size + self.empty_slides = [] + # self.label_file = label_path + recursive = True + + # read labels and slide_path from csv + with open(self.label_path, 'r') as f: + temp_slide_label_dict = json.load(f)[mode] + for (x, y) in temp_slide_label_dict: + x = Path(x).stem + + # x_complete_path = Path(self.file_path)/Path(x) + for cohort in Path(self.file_path).iterdir(): + x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x) + if x_complete_path.is_dir(): + if len(list(x_complete_path.iterdir())) > 50: + # print(x_complete_path) + self.slideLabelDict[x] = y + self.files.append(x_complete_path) + else: self.empty_slides.append(x_complete_path) + # print(len(self.empty_slides)) + # print(self.empty_slides) + + + for slide_dir in tqdm(self.files): + self._add_data_infos(str(slide_dir.resolve()), load_data) + + + self.resize_transforms = A.Compose([ + A.SmallestMaxSize(max_size=256) + ]) + sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1") + sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2") + sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3") + sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4") + sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5") + + self.train_transforms = iaa.Sequential([ + iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"), + sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")), + iaa.Fliplr(0.5, name="MyFlipLR"), + iaa.Flipud(0.5, name="MyFlipUD"), + sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")), + iaa.OneOf([ + sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")), + sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")), + sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine")) + ], name="MyOneOf") + + ], name="MyAug") + + # self.train_transforms = A.Compose([ + # A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0), + # # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5), + # # A.RandomGamma(), + # # A.HorizontalFlip(), + # # A.VerticalFlip(), + # # A.RandomRotate90(), + # # A.OneOf([ + # # A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50), + # # A.Affine( + # # scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)}, + # # rotate=(-45, 45), + # # shear=(-4, 4), + # # cval=8, + # # ) + # # ]), + # A.Normalize(), + # ToTensorV2(), + # ]) + self.val_transforms = transforms.Compose([ + # A.Normalize(), + # ToTensorV2(), + RangeNormalization(), + transforms.ToTensor(), + + ]) + self.img_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(p=1), + transforms.RandomVerticalFlip(p=1), + # histoTransforms.AutoRandomRotation(), + transforms.Lambda(lambda a: np.array(a)), + ]) + self.hsv_transforms = transforms.Compose([ + RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)), + transforms.ToTensor() + ]) + + def __getitem__(self, index): + # get data + batch, label, name = self.get_data(index) + out_batch = [] + seq_img_d = self.train_transforms.to_deterministic() + + if self.mode == 'train': + # print(img) + # print(.shape) + for img in batch: # expects numpy + img = img.numpy().astype(np.uint8) + # print(img.shape) + img = seq_img_d.augment_image(img) + img = self.val_transforms(img) + out_batch.append(img) + + else: + for img in batch: + img = img.numpy().astype(np.uint8) + img = self.val_transforms(img) + out_batch.append(img) + + # if len(out_batch) == 0: + # # print(name) + # out_batch = torch.randn(self.bag_size,3,256,256) + # else: + out_batch = torch.stack(out_batch) + # print(out_batch.shape) + # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch + # print(out_batch.shape) + # if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]): + # print(name) + # print(out_batch.shape) + # out_batch = torch.permute(out_batch, (0, 2,1,3)) + label = torch.as_tensor(label) + label = torch.nn.functional.one_hot(label, num_classes=self.n_classes) + # print(out_batch) + return out_batch, label, name #, name_batch + + def __len__(self): + return len(self.data_info) + + def _add_data_infos(self, file_path, load_data): + wsi_name = Path(file_path).stem + if wsi_name in self.slideLabelDict: + # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset + label = self.slideLabelDict[wsi_name] + # print(wsi_name) + idx = -1 + self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx}) + + def _load_data(self, file_path): + """Load data to the cache given the file + path and update the cache index in the + data_info structure. + """ + wsi_batch = [] + tile_names = [] + # print(wsi_batch) + # for tile_path in Path(file_path).iterdir(): + # print(tile_path) + for tile_path in Path(file_path).iterdir(): + # print(tile_path) + img = np.asarray(Image.open(tile_path)).astype(np.uint8) + img = torch.from_numpy(img) + + # print(wsi_batch) + wsi_batch.append(img) + + tile_names.append(tile_path.stem) + + # if wsi_batch: + wsi_batch = torch.stack(wsi_batch) + if len(wsi_batch.shape) < 4: + wsi_batch.unsqueeze(0) + # else: + # print('Empty Container: ', file_path) + # self.empty_slides.append(file_path) + # wsi_batch = torch.randn(self.bag_size,256,256,3) + # print(wsi_batch.shape) + # if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]): + # print(file_path) + # print(wsi_batch.shape) + # wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size) + idx = self._add_to_cache(wsi_batch, file_path) + file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) + self.data_info[file_idx + idx]['cache_idx'] = idx + + # remove an element from data cache if size was exceeded + if len(self.data_cache) > self.data_cache_size: + # remove one item from the cache at random + removal_keys = list(self.data_cache) + removal_keys.remove(file_path) + self.data_cache.pop(removal_keys[0]) + # remove invalid cache_idx + # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + + def _add_to_cache(self, data, data_path): + """Adds data to the cache and returns its index. There is one cache + list for every file_path, containing all datasets in that file. + """ + if data_path not in self.data_cache: + self.data_cache[data_path] = [data] + else: + self.data_cache[data_path].append(data) + return len(self.data_cache[data_path]) - 1 + + # def get_data_infos(self, type): + # """Get data infos belonging to a certain type of data. + # """ + # data_info_type = [di for di in self.data_info if di['type'] == type] + # return data_info_type + + def get_name(self, i): + # name = self.get_data_infos(type)[i]['name'] + name = self.data_info[i]['name'] + return name + + def get_labels(self, indices): + + return [self.data_info[i]['label'] for i in indices] + # return self.slideLabelDict.values() + + def get_data(self, i): + """Call this function anytime you want to access a chunk of data from the + dataset. This will make sure that the data is loaded in case it is + not part of the data cache. + i = index + """ + # fp = self.get_data_infos(type)[i]['data_path'] + fp = self.data_info[i]['data_path'] + if fp not in self.data_cache: + self._load_data(fp) + + # get new cache_idx assigned by _load_data_info + # cache_idx = self.get_data_infos(type)[i]['cache_idx'] + cache_idx = self.data_info[i]['cache_idx'] + label = self.data_info[i]['label'] + name = self.data_info[i]['name'] + # print(self.data_cache[fp][cache_idx]) + return self.data_cache[fp][cache_idx], label, name + + + +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + +def to_fixed_size_bag(bag, bag_size: int = 512): + + #duplicate bag instances unitl + + # get up to bag_size elements + bag_idxs = torch.randperm(bag.shape[0])[:bag_size] + bag_samples = bag[bag_idxs] + # bag_sample_names = [bag_names[i] for i in bag_idxs] + q, r = divmod(bag_size, bag_samples.shape[0]) + if q > 0: + bag_samples = torch.cat([bag_samples]*q, 0) + + self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]]) + + # zero-pad if we don't have enough samples + # zero_padded = torch.cat((bag_samples, + # torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3]))) + + return self_padded, min(bag_size, len(bag)) + + +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + + + +if __name__ == '__main__': + from pathlib import Path + import os + + home = Path.cwd().parts[1] + train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' + data_root = f'/{home}/ylan/data/DeepGraft/256_256um' + # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments' + os.makedirs(output_path, exist_ok=True) + + n_classes = 2 + + dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20) + # print(dataset.dataset) + # a = int(len(dataset)* 0.8) + # b = int(len(dataset) - a) + # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b]) + dl = DataLoader(dataset, None, num_workers=1) + print(len(dl)) + dl = DataLoader(dataset, None, sampler=ImbalancedDatasetSampler(dataset), num_workers=5) + + + + # data = DataLoader(dataset, batch_size=1) + + # print(len(dataset)) + # # x = 0 + #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5 + c = 0 + label_count = [0] *n_classes + print(len(dl)) + for item in dl: + # if c >=10: + # break + bag, label, name = item + # print(label) + label_count[torch.argmax(label)] += 1 + # print(name) + # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG': + + # print(bag) + # print(label) + c += 1 + print(label_count) + # # # print(bag.shape) + # # if bag.shape[1] == 1: + # # print(name) + # # print(bag.shape) + # print(bag.shape) + + # # out_dir = Path(output_path) / name + # # os.makedirs(out_dir, exist_ok=True) + + # # # print(item[2]) + # # # print(len(item)) + # # # print(item[1]) + # # # print(data.shape) + # # # data = data.squeeze() + # # bag = item[0] + # bag = bag.squeeze() + # original = original.squeeze() + # for i in range(bag.shape[0]): + # img = bag[i, :, :, :] + # img = img.squeeze() + + # img = ((img-img.min())/(img.max() - img.min())) * 255 + # print(img) + # # print(img) + # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + + # img = Image.fromarray(img) + # img = img.convert('RGB') + # img.save(f'{output_path}/{i}.png') + + + + # o_img = original[i,:,:,:] + # o_img = o_img.squeeze() + # print(o_img.shape) + # o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255 + # o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0) + # o_img = Image.fromarray(o_img) + # o_img = o_img.convert('RGB') + # o_img.save(f'{output_path}/{i}_original.png') + + # break + # else: break + # print(data.shape) + # print(label) + # a = [torch.Tensor((3,256,256))]*3 + # b = torch.stack(a) + # print(b) + # c = to_fixed_size_bag(b, 512) + # print(c) \ No newline at end of file diff --git a/datasets/data_interface.py b/datasets/data_interface.py index 056e6ff..efa104c 100644 --- a/datasets/data_interface.py +++ b/datasets/data_interface.py @@ -2,15 +2,25 @@ import inspect # 查看python 类的参数和模块、函数代码 import importlib # In order to dynamically import the library from typing import Optional import pytorch_lightning as pl +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.fit_loop import FitLoop + from torch.utils.data import random_split, DataLoader +from torch.utils.data.dataset import Dataset, Subset from torchvision.datasets import MNIST from torchvision import transforms from .camel_dataloader import FeatureBagLoader from .custom_dataloader import HDF5MILDataloader +from .custom_jpg_dataloader import JPGMILDataloader from pathlib import Path from transformers import AutoFeatureExtractor from torchsampler import ImbalancedDatasetSampler +from abc import ABC, abstractclassmethod, abstractmethod +from sklearn.model_selection import KFold + + + class DataInterface(pl.LightningDataModule): def __init__(self, train_batch_size=64, train_num_workers=8, test_batch_size=1, test_num_workers=1,dataset_name=None, **kwargs): @@ -109,7 +119,7 @@ class DataInterface(pl.LightningDataModule): class MILDataModule(pl.LightningDataModule): - def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs): + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=50, n_classes=2, cache: bool=True, *args, **kwargs): super().__init__() self.data_root = data_root self.label_path = label_path @@ -124,27 +134,29 @@ class MILDataModule(pl.LightningDataModule): self.num_bags_test = 50 self.seed = 1 - self.backbone = backbone self.cache = True self.fe_transform = None + def setup(self, stage: Optional[str] = None) -> None: home = Path.cwd().parts[1] if stage in (None, 'fit'): - dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone) + dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) a = int(len(dataset)* 0.8) b = int(len(dataset) - a) self.train_data, self.valid_data = random_split(dataset, [a, b]) if stage in (None, 'test'): - self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone) + self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1) return super().setup(stage=stage) + + def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + return DataLoader(self.train_data, batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, #sampler=ImbalancedDatasetSampler(self.train_data) def val_dataloader(self) -> DataLoader: return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers) @@ -187,13 +199,92 @@ class DataModule(pl.LightningDataModule): if stage in (None, 'test'): self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone) + return super().setup(stage=stage) def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, + return DataLoader(self.train_data, self.batch_size, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, #sampler=ImbalancedDatasetSampler(self.train_data), def val_dataloader(self) -> DataLoader: - return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers) + return DataLoader(self.valid_data, batch_size = self.batch_size) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.test_data, batch_size = self.batch_size) #, num_workers=self.num_workers + + +class BaseKFoldDataModule(pl.LightningDataModule, ABC): + @abstractmethod + def setup_folds(self, num_folds: int) -> None: + pass + + @abstractmethod + def setup_fold_index(self, fold_index: int) -> None: + pass + +class CrossVal_MILDataModule(BaseKFoldDataModule): + + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs): + super().__init__() + self.data_root = data_root + self.label_path = label_path + self.batch_size = batch_size + self.num_workers = num_workers + self.image_size = 384 + self.n_classes = n_classes + self.target_number = 9 + self.mean_bag_length = 10 + self.var_bag_length = 2 + self.num_bags_train = 200 + self.num_bags_test = 50 + self.seed = 1 + + self.backbone = backbone + self.cache = True + self.fe_transform = None + + # train_dataset: Optional[Dataset] = None + # test_dataset: Optional[Dataset] = None + # train_fold: Optional[Dataset] = None + # val_fold: Optional[Dataset] = None + self.train_data : Optional[Dataset] = None + self.test_data : Optional[Dataset] = None + self.train_fold : Optional[Dataset] = None + self.val_fold : Optional[Dataset] = None + + def setup(self, stage: Optional[str] = None) -> None: + home = Path.cwd().parts[1] + + # if stage in (None, 'fit'): + dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) + # a = int(len(dataset)* 0.8) + # b = int(len(dataset) - a) + # self.train_data, self.val_data = random_split(dataset, [a, b]) + self.train_data = dataset + + # if stage in (None, 'test'):, + self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes) + + # return super().setup(stage=stage) + + def setup_folds(self, num_folds: int) -> None: + self.num_folds = num_folds + self.splits = [split for split in KFold(num_folds).split(range(len(self.train_data)))] + + def setup_fold_index(self, fold_index: int) -> None: + train_indices, val_indices = self.splits[fold_index] + self.train_fold = Subset(self.train_data, train_indices) + self.val_fold = Subset(self.train_data, val_indices) + + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.train_fold, self.batch_size, sampler=ImbalancedDatasetSampler(self.train_fold), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, + # return DataLoader(self.train_fold, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + #sampler=ImbalancedDatasetSampler(self.train_data) + def val_dataloader(self) -> DataLoader: + return DataLoader(self.val_fold, batch_size = self.batch_size, num_workers=self.num_workers) def test_dataloader(self) -> DataLoader: - return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers) \ No newline at end of file + return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers) + + + diff --git a/models/AttMIL.py b/models/AttMIL.py new file mode 100644 index 0000000..d1e20eb --- /dev/null +++ b/models/AttMIL.py @@ -0,0 +1,79 @@ +import os +import logging +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +import pytorch_lightning as pl +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR + + +class AttMIL(nn.Module): #gated attention + def __init__(self, n_classes, features=512): + super(AttMIL, self).__init__() + self.L = features + self.D = 128 + self.K = 1 + self.n_classes = n_classes + + # resnet50 = models.resnet50(pretrained=True) + # modules = list(resnet50.children())[:-3] + + # self.resnet_extractor = nn.Sequential( + # *modules, + # nn.AdaptiveAvgPool2d(1), + # View((-1, 1024)), + # nn.Linear(1024, self.L) + # ) + + # self.feature_extractor1 = nn.Sequential( + # nn.Conv2d(3, 20, kernel_size=5), + # nn.ReLU(), + # nn.MaxPool2d(2, stride=2), + # nn.Conv2d(20, 50, kernel_size=5), + # nn.ReLU(), + # nn.MaxPool2d(2, stride=2), + + # # View((-1, 50 * 4 * 4)), + # # nn.Linear(50 * 4 * 4, self.L), + # # nn.ReLU(), + # ) + + # self.feature_extractor_part2 = nn.Sequential( + # nn.Linear(50 * 4 * 4, self.L), + # nn.ReLU(), + # ) + + self.attention_V = nn.Sequential( + nn.Linear(self.L, self.D), + nn.Tanh() + ) + + self.attention_U = nn.Sequential( + nn.Linear(self.L, self.D), + nn.Sigmoid() + ) + + self.attention_weights = nn.Linear(self.D, self.K) + + self.classifier = nn.Sequential( + nn.Linear(self.L * self.K, self.n_classes), + ) + + def forward(self, x): + # H = kwargs['data'].float().squeeze(0) + H = x.float().squeeze(0) + A_V = self.attention_V(H) # NxD + A_U = self.attention_U(H) # NxD + A = self.attention_weights(A_V * A_U) # element wise multiplication # NxK + out_A = A + A = torch.transpose(A, 1, 0) # KxN + A = F.softmax(A, dim=1) # softmax over N + M = torch.mm(A, H) # KxL + logits = self.classifier(M) + + return logits \ No newline at end of file diff --git a/models/TransMIL.py b/models/TransMIL.py index 69089de..ca2a1fb 100755 --- a/models/TransMIL.py +++ b/models/TransMIL.py @@ -44,23 +44,23 @@ class PPEG(nn.Module): class TransMIL(nn.Module): - def __init__(self, n_classes): + def __init__(self, n_classes, in_features, out_features=384): super(TransMIL, self).__init__() - self.pos_layer = PPEG(dim=512) - self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()) + self.pos_layer = PPEG(dim=out_features) + self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU()) # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) - self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) + self.cls_token = nn.Parameter(torch.randn(1, 1, out_features)) self.n_classes = n_classes - self.layer1 = TransLayer(dim=512) - self.layer2 = TransLayer(dim=512) - self.norm = nn.LayerNorm(512) - self._fc2 = nn.Linear(512, self.n_classes) + self.layer1 = TransLayer(dim=out_features) + self.layer2 = TransLayer(dim=out_features) + self.norm = nn.LayerNorm(out_features) + self._fc2 = nn.Linear(out_features, self.n_classes) - def forward(self, **kwargs): #, **kwargs + def forward(self, x): #, **kwargs - h = kwargs['data'].float() #[B, n, 1024] - # h = self._fc1(h) #[B, n, 512] + h = x.float() #[B, n, 1024] + h = self._fc1(h) #[B, n, 512] #---->pad H = h.shape[1] @@ -83,22 +83,22 @@ class TransMIL(nn.Module): h = self.layer2(h) #[B, N, 512] #---->cls_token + print(h.shape) #[1, 1025, 512] 1025 = cls_token + 1024 + + # tokens = h h = self.norm(h)[:,0] #---->predict logits = self._fc2(h) #[B, n_classes] - Y_hat = torch.argmax(logits, dim=1) - Y_prob = F.softmax(logits, dim = 1) - results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} - return results_dict + return logits if __name__ == "__main__": data = torch.randn((1, 6000, 512)).cuda() - model = TransMIL(n_classes=2).cuda() + model = TransMIL(n_classes=2, in_features=512).cuda() print(model.eval()) - results_dict = model(data = data) + results_dict = model(data) print(results_dict) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] + # logits = results_dict['logits'] + # Y_prob = results_dict['Y_prob'] + # Y_hat = results_dict['Y_hat'] # print(F.sigmoid(logits)) diff --git a/models/__pycache__/AttMIL.cpython-39.pyc b/models/__pycache__/AttMIL.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ee3af4e0d05559e5bc7a4d8ed62481a3254c9f1 GIT binary patch literal 1551 zcmYe~<>g{vU|?u4o0s&IgMr~Oh=Yt-7#J8F7#J9eTNoG^QW#Pga~Pr^G-DKF3PTE0 z4pT036f+}8j5&uTmo<tN%x27Ci(&(-XV2w`;(&^A=5jDGFf!zFMR7w#cyf88c%f{* zDE<_N6qX!;T)`+oMh16=6xI~B7KRkIRE{j6X67j26vki%P4<@{_xWiu-r{pCN-xb# z%_|8=EGkYd(qz2F?O&9VT9lgNl9^nh$#_fFCo?ZKu_!#TD7Uo0IlnkFFV!(GFEueI zGcVmIC>dl9G6wk-#4lE3U|>jP0QoP9DTS?_A&oJGDTTR(Gm1HdC55$xA&RA)frTN8 z737**Y>p)*zMeiW8JHLtl9@sJp%}zwV_;wa(Z$mk7#K<zN*HSxnwb_bEo5M1s9~yM zh-Xe=tYL^}Ndd80Yd~za67~fgCF~10YZ$VaYZ$Uv7BZEvFW_3pun?q%8>=i2RF)U2 zkEMnoixsS%4^2I*2ty4+7TZFmg&@25Qy797G@1QgHZU+SykulxV9*q}#adiikXls4 z%D}*Iiz7ZhGcU6wKK>S?&n-rmTa4bfIP>C@a}tY-Q;Tmg<>lSt3QjF7P0cIGOw75( z1_|z4EFp<`8Mn9-OG-cz`FZhSx7dR-({uAPQ;<bNZwVrll&5B<XOtA*;sP6<nU<Ma zq{)1Xr8qSwt%wiol(f{ulG38o;+2fIIO8F1h>u^%@XJC!BR@A)zcME=Prp3BD7&~I zF*#K~q$n}3*w@oX*CjQzz`ZCjtwcXJKP5G%SRWEydIgn5pp-8FO4~}H7-QsPWP?CC z77+U%8!L<!Vq#$w`Cr8k)d~xgWG0Y4DCT5fU;u@gGswU!1_p*2#sv%u85S_6FfL@Q zWvXFVz*NJukg=AzhG79Sn8i}Vuz;n8xrRlAVIgBJYYl5HTM6p|wi?zHrWD47%(bi~ z>{%RO9&;~KEo%wu0?r!N8m5Izwd^$v3%F|77c$mz)NlkdXtMYffg-Pn0~CXdx0q8h zb2Yhdv8LtZCzjk|FD@)iO|42T5@ujvC=vw`VxRzJEy*uR&bY-{0t%^u{NmJGjBdBs zi}TY;auX|VG3DkKiGz&eERq1JW~{iy=n=*27#{{ALT@oTMzQ9Xmc%>WV)VVmmXn{J zSyEgi12P;;$b$5+r{x!wCl;lEgG2-rPE34^T#Q_de2k!w<zwVw<YKDggN7<pu_j{? zDAQ=Nff8R{Vs2`D{4K8d_}u)I(i{+*CqBNgG%*Jv1CGKXZIETW&@_>pnv)YBkK`07 zkfT8E0VipGQ1pN@ItL@4F<3-X<Q7wY@h$e8{Pgt9y!2ZfiACuJ-~t5{y|=iN^K)`i zlS?x5^NNc=S-D6NWE}@M^z`!bia;ht3B$O0X{C8!#fdq$xDf0jWd;U@C?T+_vdrSl z{5(BKa=gU`r3DKr!Lso=pp1|QatTvGPLx7HPJB{+PD!y|enClQZe~?#QL$c5QG9W7 yMk**oii$v<xy9xS4j-_Gko=B_3l1Ad2-<-XN--!hI2bt?c^G+^K%@|aW(EMPkda0J literal 0 HcmV?d00001 diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc index cb0eb6fd5085feda9604ec0f24420d85eb1d939d..a329e23346d376b150c7ba792b8cd3593a128d95 100644 GIT binary patch delta 1124 zcmZpWS}4hz$ji&cz`($;|LDpj_Kmy`85yG{zh^wd$UAutQ?Wch0|P@59|Hq}CR33B z0|Ub?#)={lkQkT{o$SY)tjP=Fi-RP2<KvTa5{rwIQ*(0S<B=2!fs}}DzQxSLD8K_^ z%Yg{6M(N4lSd^qdW*2dQ1i?xnrh~Y=lZ{xPFe*$IXVc<V1qrBu2=&Q6Y-(&^O=6QP z*h~b#8o-1oNM3yM9yTLa2@q2fL`Z=Muyr7J7RiFRVAWs(Y$aG9$XOuCVipDl1`b9( zwj#O7ADKlayR*O1(Ew>u00o6;5!esrAOWzUAZ@oe3-XKOa}q04i;AQt$8$(3ySw^? z-eQSQOE$d48IV|%n44OXT2!P93Mkg1#JrTeBGbvUIaG9HK<0qL7_8U`8q!7fAYri6 zz>WmFk9YEG4x>;eP<%l#DDK!97#N%x7#NB_FfcIGFfL$N$WY5v!?1v{hG`*VEmI9s z7E=nNBttE833Ca{0@f0i1#C6U3mF$OHZvkbCg*TUGqY!LOzz^8t!GSO>ScoJWR_&8 zWv^k#;#|O0!=A#D!m^Ng0rx_NT8<L-ERGtE8m2U+6qtI31w2qY85Z!?Fx9Zsur6c< zyM+%f&JSXxFr+ZFGNmx5F-bz)DgY9#<*eZhX3%6+^(z7ebP+hn85tNDUjG08|G%c- zE!MQ0{KS$X8&J?_PVVKB<OT(2ksgR|I(Z|Pc)bOP=?Ee~AzI`FiZ3p3(k&7Lg*;<M zkpTk(LljfIM-)?hcobJ+N=kfAYF>ItMv)On0atQPaePUBc4}VnEw-He^vsfCq=XEP z1yC9SIiVO7?+i>Fj9koIj9QF5j2g`KT#O=&Ld;yue9T;IV$58ORRSSJiFw7oo<4eM z`9<Z4MJYvQAlHH86ik4l7;G*mL4mviO?^DPMIez|9P#nFiJ5uv@tVxwV96{HXJ9A- zrO+ZakoOqDNkx<K7E4N^W=Ud^FUT8gk?{pZ`AN4}BjYm?OEe`R>5ii$zbH9FFE6i1 zda@*s3ePP@w_B`vrMU%_MQ)S5c{HU!J_6fX1kOBmAg6kP2>;1_JmPv$JVmL+r8y<V b@hO?fC19J10zlG%AOh?i1R*;46ptAIoyY3q delta 1235 zcmZ1|*&xN6$ji&cz`(#@|1>|Ta3k+SM#fu{-!q<J<eR*QsaRfsfq|ijpMilvlc`9M zfq~%`V?~iDNDNGfP4;6>*5m{6B|wtA@$tzyiN(dqsX00E@kk1VK}y6n-(u!r6yOE1 z<v|2kqs-)QEJ{)!vx_)Df?%Z((?MLm$wsVC7!@arvuSawfdte+gvMkaHZ?Y|Ch^G? zY$gI=4PZhHBrh>}51WyzB#0>mBBVhC*gB9qi{wCDuxc;?wi2ul<SbbR28Lo51_lNW zMn1M8`N<!dMJBtmztPbIX;K6Qg=i7j4;CN+u%RGrw>S&(i{o<=D^rV#WG2URNGb=V z`h-TY#HS@2-r@{MEK1BxElDjZ(gOt)Yf)leN?wuK<k=i5Iv@+dVSEdu*a#ZZMfM=$ z!A=7^66`*{$*(z#av4GK1;rqjurn|)c!OLm#K^!<!<fP-$*_Q7Aww-w4O13l3X>#5 zEprJ|3G)J$66OW0HOvbc7cw?8Vu~;^)UwpF*09vDiZIkLOEA>3m9S;8PoBUjThElj z+{*+tk42K9mc51{i(>(24SNb}3hP4V1zZamYB@lvYdC6{(wI_U>KPVr*Dx()tYxZU zSin=m1TuReGnmH<7v}@9QW#PgTA5N<(m<{SspbcX)^e7xWeL=9rb7+nDiO>Qs^Nt4 zxNA6TxNEpl*yb?T^3?DIGib8=-C{{eEJ-W^MN<(t!a!m5^8f$;|23I!vE}5aXO<M- zVvCF~D9TT|#Tpr(kyxTBe2X<LCqJ>INNaK+mt-s`f{OG(fo}$iDz?;$g2cR(TP(?? zDTzf6AaP3&;Rqr?Iitu4tSqr8JvXu97NgrO_Tv1slH9}!O<qWrVaqO0EJ`oF#h6iK zIC%q?M7=Rc16OiRaePUBc4}U6kt@h7cM#zQlHw^!EiTO|DUMIcOfEsnMc^a=O29Ut z$W{WS045$5E=Daz9!3pjE=CbXA!aUSK4vavF=j4CK1M!99>ywxkfOxAVqZ@my|nzI z^2DN)B3pBiP2fZYCctS7Y&IxGgS-z-{ye-zAdy=f@$tEdnR)T?n#@Itle4*{)r*`# zMu43Sa&ZyJ?V6I1w9QeHUzD7omzP&0GkG(&3SUtG$WN?!rMU%_MP8F1b8Ciy@(bAL zB5)S90~zKAB0$F8V$ID@NzJ*%T2PdkS8|IbwJb5GC=?`%Y9CU-fW09$Ih)4}0Bh9^ Ae*gdg diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc index 0bf337675c76de4097e7f7723b99de0120f9594c..e9d22d7ddccbb2a3d2ab59b98e066f9ac5d42285 100644 GIT binary patch literal 14054 zcmYe~<>g{vU|@JOWkZs2C<DV|5C<8vFfcGUFfcF_w=ps>q%fo~<}gG-XvQceFrPV! z8BDW8v4CmTC{{4d7R3&xIifgI7*d#WxN^CpxEVounR9q@d87EiY{ne^DE<_N6qX!; zT)`+ouo!EOP_A&4FqqAjBa$l`B?@M<=ZNKsM~Q>k961uXl2MXiHfN4Ru2hs%u5^@i zu1u5+Sd1}8HcA$3zg(_-lsr^SAy=M>fsr9sF-i$4q8z21!jQt1qmru{r3x0~&XLVk zi&6u#8FSR5G*Wm{cyly!wW74ZVthH;xjIogU^ah_ZmwRG9+)kVqn~RKWsqwaWeDaA z<{0G~M;YguM45p3LOG_nW>IEfws4Mlu0@mum@SfHnQIkg#mL~!kRqBQ*20h?ma36u z-OL<i<Ia#Go+8o0kRp*PpJm(39A%fn7|ft4`4SXEewvK8SOOADGBg=)u@zJn<m8uV zGTvhK%uP&B)nvRSkX)3SSdto_Ur>^nn^~1wq{(=T8!DEQUtFxocuP3BD8INkJ~gkT zD8HaGz9hdW8DyBGV{v6}ZfZ$UX0lIyadB{FUV2WdPhwJPjwbUhDKH~AH$T55BQr1E z8DxSh#DH63!TD(=A&EulsU;}l{9eT=$O0gj6qh8H#1|*$7o~z+!U}ffOGX9;22I9W zoMo9M@x{4`IXNJ&<>sfP=71Ej6{QyErIu(i-4byuEy;IFO)M!bN_DL$DN0N($uGLa z5tbR^3sS7fbc@}+C^5y^(f1aQYi4?C9+aoacuTN2BfmU8IWadrKQBHL8dhS?If=!^ znQ57+MgFBF1*K3=Xfod7@ky*qEdpy#%uNObJ2Ga5a*Ep+7#LC+q8L*cqL|tl(il@1 zQ<z#fqL@>dQ&?IUqF7Q`Q`lM<qF7VdQ#e`}qS#V6Q@C0fqS#ZoQ+Qe!qBv4`Q}|jK zqBv9dQv_NVqPS87Q-oR=qPSCpQ$$)AqIgo|Qp8fkTNtBwQzTL(TNtAFQlwI(TNtAF zQ)E(PTNt7Q+8J0Fq6C8(H05vcfdj!auOzi7EipMY8I*;g4rLHvU|`^7U|`?|WoZ@` z28I%b62=;aW~K#93mF(0Y8VzULunQeUBX(!)XZ4J5YJY_T*DC0Uc;EhQNs|=QNmfn zkj2r=$jDH`lENs-P{R<<Rl;4vkj2%^n8lOAB*~D%oXu3!RKkl8ZDy=xO<}2FO<}TO zNMT*TSHrpxWHMI@TMBy%LkT~`L>C6IeF7j}3P(1}1jZtf62UAXFrPDptCy*UAzr9N zaDi}%V3r72lp7>kBA6u#=JE70mI!8vf!VxJ8SxZeFrTlNsYEbK0);1u;H9%IWUOV& z3jx`e!U2jSRNGQGQs8!_aHK$ON&&?l7uaO+6y6k$6u2!Ab~-D_)DpogDX<%OdO^OH zE)mR<LE_8Suw}`CVjnCk59SG`2!MGCH7qGYk_=gjH4O1eDZ(`j@yazUDI(xdS4k1A zVTf0S@KeOHnI<q6eJD{|pq?TQ4qc5BO*mU?0%P70s7i@!rU^_%c_rFeIw_JVQoW3| z>?OQ4><e@kGBh)$Fr-MgGD$Mja+K(0>DO@dG1PL_a4s-d$l$^d%TvphB2&W^Z#aRe z(4a(bfqo5ViQWRE6xoH0CB`Xo&5R{_3rs++%2Hj(RHC=QbRok+)*6O*knL_z+vUNw zKPl19GK1RAUBX+#4YyqZ&32v|9xS%61KZAnWV<5NcBL8yknIauK=$TcfZD2@%`}0z z$f3kM%OcA%MI}WQY&UZ)Zw)Vo>CA;7b4^mzASTuDW?98sgYDym`U#W{wem`!W~+nE z-jX6y%Lh_bqL*cp!jz(sqM4%A%Uq(jz;+=+iC&gn3P%b@igqs(SY9JV2dBJo3YuRa zW-R0c`KX35%RWUHY~zy}hAf9H$1JCG##;UoXArN3KSi>JA<HF2uNUNh*E~=<bI5W- zRgt3K3(C>1DGb33ng)Kim~-;eUxLcBD3+4s+@dHh5E);VS(KRbi%~xrR5HWL5GDo& z237_JhG0+`BF4bLkj_xU5X)B!N;eECjB}VKG8M7}GpuAP0_oCZy2YBDUz%5Pi#<8N zs3<kLWF^B&rdv#U2Dcb9ixe0b7{J6YNBxZa+*JL_oWwl+^8BLg;)2BFRQ-^m#Jpl( zPaj>E)YJm^qQtZkeMo%(q2e=PrJP<t<t;9ooXp~q<ow(MJ68q<hR+}eS81b`hI;Yw znR%Hd@$q^#Ir+(nImLE*2wg8hUc1Fno|>7SQIeXX$#jc5uQWG48B{x@7T;n|EK4m) zOi#VVnwyzil&{HHq{+a*aEmoJ5yWFE(gJw~EW}clm=h1Gyb6j^(=sb=v6iG3m&D&< zNl7e81c%Bkj-<ro?4<m>)RzoQ44@h!GcUhN1XNxp<>zPXr)1{k>y~AfBpI6B;(*k9 zh87{3jJLQ7ic(985;OBsQ(pf6|Np;R5g(}1WGXTRS#AV!4^+LW!OOq@|Nqy7NWElW zVPGh-0Lk;Drln;jXQt+r<fWEWsrZF?x_CP3g96wmHL)l!GcVn>A~ClhC$-oL+$6{- zO^TAuE6YsDOpHg?8eg87l3JV^pJb2;N)NZ#iZemFUWza<FuVjM5Kujzom!NaniF4~ zSp{m8uoagSWu~OQ1i3Sc0~WcOEVo#}v3iTQ1l)Rv&jHoC#kV+0i&Nv1OH&eW2^OUm zXCxM+#+QH+L|T4Pt|oI4s9jp*2=ccRC=A?hu@;vWq!txPf_%vVi|AV-#ffF9@fno` zsYL~eMTxnoC8<Tlx43fh6I0^B))ujWG_ZhLMYlM?EfG+A=@w@mJmy&oiZb&`ZgGN5 zDo8BJC<3Px)`G;MlHyw|pk#N8B`ZHO?-pxOVqSV`kpaj=9&qqMT2jTgSRF%y{GD$J z_<~zj&iOexsmUdo`FXcE9FvnvixQJ7Z*e*2XQbwNCl(YW-ePw%3<foiZm|WWCMV|P z+~N!<N=?oz2I=4qE=Wzz1Vv&=<t-7oI7FdyX;E3~E#`nA=Ue>F`FUxjAT7R$B}JJP zw*-=50`a*J5!U3K{Jhj#yk&_wnJMuwxAK6~E0h<-14_2>#Tlh(X*sF4*rALfIZ((O zf<2mFS`wcIX^#~b34w(8(~9zQ;}MDVmLOC<J{8trD-s4N<N!xOd|F8nsBKbY3>Hhy z$jnJ8O3k~)l$UplD>${VG&QdzGco5D7g7xJ7Nr)JW)`Iu$EO!1rrZ*AOi3&#$t+8C zEK3i_&(ASRxy2HenOc5}%_lQ2HL>Uxi@U2&XptGn0#<O$++xlsO}fPbN(i^Oz?mG} z1iZzRmVArNIX|xqWJFM^Pv|X9-^2=-nR$u1so<n{iw&Gijc&2z<rn1^fs%cZGstx2 z<iy-4!NighkfZbS;vwZpv8KQ+mg3Z$G_d9(4^ZfNf|RpnCl;j_-(o3A%`Lda4h?ij zK(OcLr<CTT7Durbr4}1n+yW)I#9Xi+z(S@5MV1T<3{gU8ft#6^4$c@W8E<jMLsD^k zJh+@KQUTcuDjxMf#rG3X8Og}U$ic|L$i=}1A{n_DMHmH`I2d`D#h7^*g_w94nf|jd z@i20LX*Ncd|13;M<aZ8U0VXc692X-Sm_{OhvT+J92{7_8@-g!;vN3Wou`%*7@-Tu# zco;>PSwQNUc^J7ESs0lZIT+a(c^KJ#bFuR=@i2*i#Qw9O%KqVC=VKIN<YDAu5(3FH zvi;*=;bUZBWMSk3>xPhEbH20j@-Xr+NicIUR*B=N29iM;859U0466S?7!-@eQVa~B z_Ao;Y<3h#-j46x@8Ectp7#1+qFf9ai!kGNP1-T}3Q7R}BIZN}33rkZ|s~~AZlewrG zB+OWGixm_%#gM=US9hQoxW%59UsRr0l(G#J3*Z(ylK^9t7@8}giclO0Dp^4qz>c&5 zIg)V!1IT#`7#A|sGS)EEFg7#RGL<l8G1oAqFr_f}GSxDduw=2qc`P+dH7qsEY0SY4 znk*1+fNPPJjJKFmGIKTAZm~oCa*H)BCqJ>IC?Dh(w#1_J+{B7ojBdBsi}TY;auX{w zSs~tJ&CDw(ExE;(lb@bhQhbXwGCm`*<Q7|Gd_hru61vw}ic3-pc7yy49(rKnW8`4s zW8`D15=QeqNL?~0K0t8;!k}IYI|Bm)*z0{@ug5UgGS{-yvevLHU|7gd%TmLzfU$;o zA!7`4En6*nEeANp*lJi7GSzaHFfU-KVOhvn%T>dW#ahE9$&kXx${@**!eq|C%wWR+ z7i9*Cmax@uG&9z4Wg}$SQ&?&^Q&_=}t(TRNp@y@Dy@sQPC53$sOE7~bhaaR+1XamJ z@}P_c%H<&|nQw6v6y>MKCnXl$;>gd7Pc6t#&H#s2dTNm-3#2TV2vR=*6ws_W;LbKj zQ5Q&P5{Lkq1CE8FZV(q-y)b9yl@x(esU|lx7LpQ6k~3~`f~ok-l!_uy7!^$inaLPm z)DO}Lw*_qJEyhUn*x?1YE;94d<3YiH6ckOM0-b?{i;;s#fSHGri%E=8j!}fMN(L>u zP&A?>SxyE91`b%f{b2z0Izc%qg{g(1h7r``1Lds}h6PM1%qU{aB@7D~Q&>>MSV|bO zm}|g2>RzT=mJ$|FtTi*bFvR-RveqzUvDL6jGNiCeGSsrcc^r}qwd^%)DV(4_X$nIv zM;=cJdkuRtV-1@$sF%Z$!nJ_2hP8&HhP{SuAq%KqTNskU4YC(uKDb{BGLtKXM~ne% z){Pp5EN)O{0`=c}nHfPTh$n@=hG79ONOS?;LWYG*DFR?92ud&D*0zveGN@q;O3)w- zDu6)w8Qh!(RU+x2<}^<&V+~^pgCqmAG0j*6jyX^PvXZGt2^5trAOa)+i9%4ZQl!ej z!0-?xuExN?P-TiMhQJ9hJ~b}|TPv6>*MSl)#2xcN)?;zjBak@AS!9_HO2qih2UTT7 zAoIaxSy4O4I&c#OoDo4G{!1QQ$tzUXS}7zZCo5Ff>L@__Bns8F`qi}<E(4Xd;CR8^ z%_))vSp@d4BAf;0!XxAzNF3y{m!Jag7ArWTtYo~!T3nKto?68j0&36eCzd9M6io&- zWck54{T6d_VhLDHayBH_2Y}4+0uiA64=EkIL7AHk)F4hQxy74YT2z#pR{||yL~`=e z<I!7u;1Z)K2xLkp$TCiF*%2iQb`&hz$LE8G!HRFOf(r2BTb#Lx72vYG_!b+aiGGVQ z8B*4PLiZMPVsbJhcHln0B?LDSp4#q%N)u3=FffTQiZF?Qngw9Y!N|tM!zjim#=^%a zz*MD3Fx{e53!qR1#T2+&$Rny+C}CQ_46X`516eFJoLH&?h$y%!sNq<^x{#rU6J8at zL5qDBaIw#dUhHeKpcl5F<g3XBi8643o)3xuL|M*~S(2Ko$pwy!qM0DMSs(&leizLK zanVZfIUq4`fP)E8db`D625QJBf;tx90{lBD$oatK47dPi0T<s)RdQ%SjHV5x_y&b9 zy!h^61Qp*|kReWJ@eLa0>}9HD0vFe;pkfnVtb>YhrW7_5F&3~ob`&vIaIwys!jZz# z%T&u+!cxNu3iD>h6pn1>qA8%FJcYA{VF7y$a}DbP4sgNBl>&m?pn^3IG!zILR^<VU zaixGo!Mb=$IBS?{Seu#9WVmV=7I4?FfXu34T?moogB9W+TfoC{%!~{rpus~va1|2( zu2lGI*d!TJ1i_UGLl%T51R5@@VFwL+PGBt3NMWeu%;PB$05>7nr5S2DO9WGd7YNm` z)o_BU8E7T*BSi!h3J4p(l?-SoR5*nTJeb!}!;mGC!XXYCVihS7T>u)`6iE>kVOYor znmq_+&=e!1I?-f=Bv5d=!Bt>$GBPlL3)Cue+=&IAx3SeE;1)8tB*rLN3D$T;^Wdo) zmjk(B4kW|$1q4m!p}_Qo1Wo6InGUY4qS!$LM(`RS9h6@|IEn)#4jP!yWGb2pGIScK zyk{;;%qdz5;(`iONK11O0|Ns_IDkq;%Ihx!xLfT(HHJ92dMa8DvVR4L02O1mm~&F| zZn2f-WfqpEqSY})pf>w04p42Bl9^mm6c5t08bsi#(TY}rWD`LI$cMMsiW2iu@^f#o z6sM-9K#R#J0Z`b08mX|_3X-PRfOM?|5$iz2dJq9J1QJXeKwNOL2NU2jQv|8-gQsT> zMh1prP(91QC;_XjSOplBm_%5lScMo_{<E<Nu<(IvFD_W^#mvLV!&s$EL~cQ;zXTZ= z7(hV;uD?{k<twOp49ej(ppFczgTn;w;4m*>PGMTe2%4~|Wv*eUVQywzz_O5`maT+g z0V}vpV=n=9Z`h!vGDitp7O3&d3@%wYOE^Fc-*j*%XaOgvtF(X%ZXRd~hr5QmhAV}& zm#LPggdvNk1k{-2S;&~e2G-A8!dJr!Y9NE?EPgPHAxof!7vicE_7tWRmR{x<rdqyQ z{u<^Qen|$f>jV(4t6|6ztYMdANZ|l=vr?GhoiB(eCrGr0V}Z~@h8lLH8W&V<vedBF z2&8b$VGCx^<n{}JlrNy#4P3}*GTq_<cSPe$QuB)Qi*B()+A76G>Y!K$^_3BA7ELxt zasfB1!L3(N9bHro%3qu%`32eWpiwVPVQ^x<B~YAN0v=|HPt8kA%1KRuROOu@)odxL zC5g!ykTzNm$fjNpAr2xyV*^FaAXXJfDNAl@V%{ybkW`TKiok6*R`9syEf$cKMWFOn zv<Z~S1ksxUU@5p0A#MjZ0Kjes`vR2Squ4>tj!!Fz;!Q6~tSnAW%t-}Jciv(vF3rtN zO}WJdvM9bJGbgq977v)6nU|88oLXG87vxNEt^gC@ticH$2n1y^b5MTJ02NA%T#Q_d z9L#)-JWO1m^v}Y@B*MbSDgqunl49f$;bIhHlw+!rK+6Ykqfi<b($E=6P(Q!;5EH0z z0ZRO!Y7yRu05vSg=o^9(Bdmwn$(X`{FtZahlEVV-qp^aT5}=;r0=61vP?G~ZR=`%n z3hFP~FqE)^X6c$4Q@ErVYC#!`vxaE_YYO*5#uT0u?i4m~Qs=1UECG!vaDy5N6BuLp zIvGlMQg}NUnwc0G+8Npz)0k5DKr_NMoS@!rEmu23J5w4{3O|w<Xa<a<hO36Vnc0OQ z_EZdWEl(|PEnf*w4Q~xk2ZJQotQy`FZb^n({u=fYz7B?FMrno`erbkUff|7<erPX~ ziIJg%t3;rNubC0#n%J6J!5Y3AP-8`?MreWHLIzOAS|FUlnj)0Ky^y(9s0Ji1vXDW7 zp+;zd=t2+)=1DW82#YX)<QIsQh-XPGWULh`kz62EBeamQR=7rZfpiV~Y=#t(4h9fe z!w)6}XEUUT&gB3VT@x4!{U$ILvJ@VuVOt;ro;%cuVX76W6)oYZ5t+@9B32_hn_(_9 zs0gVMg}JPTtwzL#VFF{}EkuaM2GxR^J+d`ok_;*0k_@%tH4Ir0o`fVrtwfD@3O9I0 z(J6(YRx*#LM6O1nnXyJ(nxR&*geOIEfqacvjbx2PjX0#)HG!#6C`AgC8W47X8(y+C z3|R`GRFNW`A_Hm)mMAVzs$ok9jr~d0NM$L*Q-iEDLy8=Xw?Jhf0}n%rJV*z)hp#|J zldOnvd#4aq4_6st%UbZN3tLkRqiI060Y<oq1S)^QEkrU*C)`9Tfq56)L;{!aq}!v( zSOhB4A)QlP4aqW?O=OvmC^vDLUk)=riUV8-CYB~g@#p3jmw+ZnKuw~|vQ$u5G8BR8 zs48~wFol%@hHF60Inbm@u>@?^RiKu!gRzDYR6f9liXdJE7vZ1*2~DP3jCr>>3KB~| zGk(x<Xi(V>9xG~xxr|D76`cT8BA`)2h%?|8b-*kN0j<s{NCXX26yM@Vt;j4ciO<gl z*U&HjgEY&4My`0km2Glju3lC_dKI&RtpcLvE}~&mcO%T1py`<65^y8CiXR+lMXAN5 zIVF0@#bvh`6E&G`v4AG7Zn0&gCZ?noX{v(T322RC#4t-y0?2Eip}ry?5bH3AK(v0r zEdl1L%z|6Y#U(|zSaY*eGK-3~f|?A>$pxjiSo2DA3o37Mf@ePq@{3b%G3RFH-C_m} z$Fn6CmsA#{-eSrtxW!tUnO9;_1j^gDIMa&qa}$e-5-V@9faWc4v4X}JiXm+b@W?W_ zkqv5x-(o6Axy9*{Sd!=lnjgKzR+1kN@+TWOdWtori$IeZQGAe5p?FA06c?ekv~Mvd zmL?a~f!qvfHGt<y!5%#Znxo<cH#T7DFiHrS3u%uOfkrfMF&5lnEQw-=DvV+WSqhqQ zRR)>N0&2M2V$RG>zr~WBT3LLHtt>I8G_|-0OEdo#Xu()%P6;?AYBEET=1EY%gVQG{ zY2M-kM=iYMEC4lfKx1PJj7S4iOahDoEF7$Sj8cpO;Mp+|MltZ58E6(ufRTeyj8OqR zBc&rCz$nBh0-DWW<YR<%3nUoDcm$Yu7^`IQmf0Bj57Zt4cL+eE?<Js-WyTtYET(2A z(7<po!%9X!P39s{flvgRL<G+%YBGW22_gV40Kf#;qavVjz|8d0BBT|8#m7M21tk^+ zCO*u53pO1uK_;TCC<8eGJ~#v#oCFP9r!cfKr7(g?1`wIXEXlBdp$0UM#<Y;JhN*@b z%w}H5xR9lmrG^DO^3Sr6X#q$V^Fk(&STKVos~@E3X9PtWxD782Y8x}(VuvmVxWxim zSWt8gBnq91VuQB*88dD%mKVJS=>)kMoCqM3pP*1Z4e~i?+Mj{Zh*5;8N&zi$pynf% z8=%AwD9AvXK_(R!fd(8IvKX=$vzTfbQy6O)%NUCEQkWJ%SO3*8E?@yo&w*Ngpm~lw zmKw%drYflt))KY_>?y1@jI$Z$f<vX5sg^mMAy0&n0R&4pQrMaqn;0V*@|X-jGm+rV zS`9-yBe;#h?gt(dyoK562e}U1=|`PR0L_m<CKJG60a{CvoLG=ql9&Tp?EtDOz~T@B zlxDyc#C2GKSfz^=M4)MN$odIv?TsqX)OBi3v0jzFXCAnJ4GQAYoKyw3%$!sOzr@^B zh5RA~n4B|ci49m(QIqQydj@!#zZjgVA@#REDEG4#7vyA?++t5pElDgXDZ0g(nOl%w zRFadKbc+|l1UaZQC-oLr7^qlqEh@?{y2YEBS6q^qmz)XqsU`~~lYm+Ux44oMb5nET zK~ws-Si$~-cpsc%z+vFdz`$@3lvqI35(6U_3kRbN6AL2~6B{GPe<ntz{~XX6?J7yI ze)zIWJw*7R^mjlVeNYU6iu~dv=v-J0V;0jw#w=#gx>4}jH|83KES4<R1#F-$e<njM zYYnJZAXm#)!&bvm#!#dMZhL|zb=jJkKz$V(hC(gSz$j>TnvtP~F^wUUp@ww{V;^X) znz`zff+unyC?qO?YOfrH(&Eg#bcK?PRLDd~aY24w3P=PT{@{f&Tnb14RHZ0@)-dQn zV@LtqddbgAEUMH~h|DilNKVXCNX#kDR{(d!6%z9*6+o*SP)sSwSI93*Eh;a{EJ?*; zH@HMs$jn2rNDpoTxJdBR<h;e6nO9tpnp|>=3zQy;(m|Xe(8`8eEFde3Kttb<6a-o? zUGx)_g4jw6QW8s2HQ6DF0My{P#g$)@ky;cF(gR%v0%5ad=7FmiR*;P0E#}0cboA8m zmw|!dF(`F_)~+zHiZSvqvVhl8h%j<tPLQIcLzF@klqf(2DY!rZb?9rL#pObV8YWP2 zSi)2To|Y_O&H@#etP2@HNdue+IGPz@bP8iHb1iELXBHP|nJswyhj{__LeNwwa|&xV zQ&CwBLl#d8n<PUP?*hIW=GhD>>~ooFK$DvM3qjMDY+0a%x1dfO$6O}J0$AuMGq)dT zAzx-*N@@jYAzw~uZeDRn6h|;90*VrIK$3}hdHE%`Sdw!S3yPM2%3V;gT_vawUxcIp zS#wmyqo14`52|?e3i8r3CoqB+eKV&NWNHc(#ey_;f(TFsDFV$H6+HlPuYicFAOhZO zy2V_aSA2^-BQ>!kH?bfJlnmn2z;)a$_MH6m^i<F?uhfcy)S}GX)Vz{gT!|?uPz6P+ zK&gQp)GkWPOb55gU}*v|bhr$x22>5^CYD68rliFu=SH!lWu`-y(jZC+P$TLiC>?;t zq8J!KZ2}c05mpWsIVOzZK+rHK$V~9MEa);UO{SvvAg_T4O`%&H@!%pUKK>S0JhViM zkH5teA75CSm;)-T;^V;yy$H1Kya+s>G#8|OBZ$}xB6fj@;~?S<hyc$(f)ZB|s5~nI zCEFs9n{V+#R|h4h=H$Ru2!XxT2+{^F0YR(jv_Z2-pg~j)W}bQmP7YQME)FRU2@XLH zeGW^GFb-BOHVz>UKJdyr+^gh4#^GBf4_?F03aa;uazLZmpyo1YzzKvwY!C)Di;6)N zTM0um18BtysIaSHh-U&-ibZch4FTpN7m(vYB|*_4kN~t<1Y#F~{RQ?R$Ti?lWMg7r zC<YbD44_QQ3tBS<UnPoCbc5^$xd9ZA#h@Ht!<YqbnDsH#GSx5zGiWka#lp*}Ajpcv zVsHruEe(^C^NYacYe_y>7(Di;07``^paKt64wq%7mMeh9@Dz$mib|6~TO1H&pC;2S z7LWpnzaVRjz>7R?G3S>;8oH1o3hX&h>fnNVj*W-0iU;gD=&B-3wjxnb?Jfo)#6bi& z$UwS^BtTqnFo6khsDaG{MGPy<Ob$jK7I0K)+TUU>t}HGB=g(W9&}RS{@C(G+0m?;e z1&MhniN)aj9K~LonwXSdlvf0r1i1xj>1X6*CP9`|a^)tLfVQ%L1VG_^i@6}D<Q8*s znb9rg08gJHP&`F(`&RnHwzYuUz)@_zm7u*Ukd>U=r6rj;#d=^8JogaAS(2HXs|T7# zDFU@si#S0J15IldftD13yKc8Qz-t2a^74v6(=SoNFs@!&X&!i?CTRH~f?Wh2=?w(+ zY{bDzp$qqrlo#=UtO0GexFuLn304%J16o=PT4Q^QsUYW;P#JhZC!(2|T2#aj3NLQ3 zYKUwRD5-(^OGU{U5D`6a<y~9^nmD+{1MU=pRsf_HK^9+v7gk0Iz^nwVjsy(>Lslq? zf?O<vE(ZxLJ@AH@BG6)oC>0z^z*fh@2G&7?+)>=gi3KI4xq6^BSrKSx;ugOT%$vU8 z1&rW=;T9W|14=`XWC~7PU;>o5ZgJS;f|c8WYK~$z(5fMDZ3<eQ2*RN070|jw5Dmg2 h%p8n7j6#ecl7|&cg4Px?2{3}z{qZmf)iZ!#EC3nAaHaqN delta 6350 zcmaEsJ1K%Mk(ZZ?fq{WR?9h}X8;6N}GK_W;wZrOJa`<xvq68QjQW#QLa|CmRqJ+R~ zwjAMHkth)`n>|M~S1d{l%;v}u%@vOl&y|Rh$d!ze1dB1|NJUAdFr+Z%NaxB#$w0+q zb7hzq7#VWqqU50>3Q-Cv3@MyBQn`vziV(5-9Hl7b6s{ER9F<(vC{?f^PmWrydXzet z&6}f<s~M#UX7lA}<!VQ1=juf1fcgA6y19B$dW;P23@HLBf-MXwf~m?``pwKy2JQ?g zLMg&63@O5?GFgVr%uz-ujKK_=A}>K9;HSxWizOhjBtw(&7CT!(WkF7U$>i&d3n#ZT z`A?Q()@00?oM$9IIi1<jRy0Mdg&~T!oq>fRiZ7T!Q~Va6Z+=Q@j%Qv;YEfEZa%wUQ z$V4dSV_;z5Wnf_7W?*0_PGO$>pV^|mh9!kjlA(qno->6hg}H^HgsX-ji=&y*g`t_T zh9RCC#7kkxW+}Q*!jr`d=CP)*^)l5k#PgQ$EZ{5Q$>Il#vV%lRc(Me*JdR$*5}qtU zFq;!9Bb34k=5zHjmGEQ<gL(BFDG;s*m<#5mvn^z-Wz7o#S(w6-f?-_>OA6ew6qXdI zRVgefNEW5AK+P3S;e@c$SwW_j@MMXC9l_BH@}5|I4QrNo3U>+*m@5J1@j`i$H7qH7 zk_=f=H4O36Df~4I@iH|mDFWcol1&k;VThN5@Kc1cnI<q6%_)&zppYWG5M;1oi4vTx zJb^Lq2vns=Hq!*ABB>IUEY%dz6tUiV##*)#wHmes>I)f~8B-Wi#9Nsp8EV-}G_o{n z*!vi2Ichi-Xf0%LVTk3a<xG*N;f&Xwz*u;vL}P(w4M&N_0-Y4eg^VS-DN@ahB^nF# zK+ejNTgX(Ru|R(z!$Q^?hIo+mZcyu`!PZYGQOPm@ThCZu%T=OQ!v(iq2F-fz8g2~h znF^!8)^j6SFAKF^u7&|*{X!Oy#d#N?*2;seeN$qXWt3%{qL88pww$?^r>33<i}5?a zK2d@gRl}2I5^oB&kO%53P~ub0D}frWoXs?WxhNq;qLvq=szf8pEJY<nHAO8&y_dN} zV}bcXh7ye|ixid=mK2R%Ca}C}iY7$f28#mS6g1yJ)GXu#`6-1Vm_bv^?-p}Te)?oS zQIpBh9I9**3=9lK5|i6Ff*6e_-{eRWQDk6XxW$^Am|T>v$yB5Sl2)7?%xS`SYjOvt zH7^eX1A``0kse4~Yw`t7QEm|PB?AisLy_U+51c71%nS?+lY_Z5xHT<ru@;vWq!twk zPrk-wz{on8om)YOBPTySz96wA;}&aCVqSV`k;ddm?ySl8xZ@<nK!$39?B~fZEs0M{ zO)M!bN-ZuDm>kNZAtwlu;K&7)5%FmyMW7O)NCzawl9O3nGI=UbETjJ9&pdY8ntZod zic@paZn1(D7ukV~wFjwW%}y*zFTTZ|n_5zonOt0?GkG?zbi6(T149&RQEIWNL6H%N zEd;eWKDDBxC^5MtGcSE5<1Nnk_~e|#;^O%DB5<I|fQ+?fU|>)Ml}YD71r;M7BL^b~ zBNr<hh-4OH=3*3LoUFp9%*Z&|gHLvHI$wExl{jiKrxzcenU`4-AD;{gIFKnI3@Ynb z85kIxL1tfKU|^_WjA5>2u4SoZtzlWfu#lmarG|L{V+!L!Mlg@Lh9Qe7g(;h<$f<@Q zi#d%Ulc9!r31c5)En5v+7E3K#9#;*+0_Ga#`h{Sj1*{;gwQMyES!^|IDa^f$wd^(Q zDJ+tpY|Ea)D#=jGS;L;fCdp6>(#T%JQNyy3sg}EhV*zIk%R<Ilo*ITMt{NUmh7@*I z21$k#4s!-(1{(&rC?`m?gu8~TnX!f^n+Yn*T~p6l!&SqQ!Zn8_m_d`<uZR&8irNee z3@aInG(d60mXn{JSyFt9Ei%5KC_m{IYh-*zVo4D=ZGaMAkvJ#?vOr?B2&B6hq>VKv zF)1}iljjy|T26jq$u0Kc!qU{#s?;J-dMXM9DG|&`EH0^!&&W(kNzIEdE=ep&y~SCY z2UC~{G7Fqm5<sfi5{uGv6DuH&V$RGf(c~@407<YWC6**-+~NdN@tG+QB}FA5ImY-~ z9MHs5e2b?jwYW5=q&Pk$Gr6Rw5M+E2Llgr;5y-({_upcSgrp^Sg1W_9Qk0mPmzkFy zUtE$}Fbx#_pghRH%*DvT$j8jZD8R_U#Kp+P%)<mqY8=cwoLo#|jB-p>GL!EMOKc7m z3S-m-2Mfr|l}traV0TrZ*;6FTz`(E{<RrPtw}kUGz)k|`C~5`?f>~e!oXSCp4}iqw zCKrlW*nzwXF%OsGLtw=q0u*H{8E-MCWadKRP<--JkvI)ekQ~?oNjMA4h1+)qBrZ2O zTU2Fos-V0Pe-SuVn3EGrZm|}ZBqnDURf1GFfd~f>0ZKhZr6Aw378GUXl}tV=swX4@ zl5qhMWuRokS(GuERm{Q&oSHxh=@xThax%m(aO-ag!Q&vcAU`=HJ~c08@8o>3uzHYZ z;JI=MQMqyfQw=jT%Q3^V97_#b3KKNTF+;N)DC>c;94PB?m9Q>g17|Z(BZj?(TaqD# zjg_Gul-r;=i6ILj%mK=0HCzig7BbXu!}A$O4J4a!g0mSHG@BKG!U38NVaX625+Hwp z^VLcgSUQEtgW`H6CpcV->Oo-$&PzoVAfv!J1{5ujr~pO4EtbrZ)ZCRk;H*+q4N_GD zB0zHBEK*bp;?{wPdJuu;oCc5>I3~aZC^g*TFH6kHOi3)sRLKWr+th-qpj-h;*9=Tt zOdO1Sj9iRdOg!MCmxFckCvgYHw8_R2sr5AsSu81>C9E|}HLT4{DO_O4-OI`dDyCRc zcxo6Hu+=cvur6R<$gmJyo}-2VRMdgwYFMGNyuHj|y<8w&DSRpXAbtrbj(9-L&e#BO zfy!IMCdrT@AX(2)%U;8f1>p)pqnR1p-eO8&sO8M#DdDT(XlATomu9HtEa6WPS|Cuv zR>N7tQNzBFrJ0eDq3}|Qa0*ik1HuY$%NEp@6-?m*o0L?;kR_DDAr5K>3zY~j5CQW= zL>Lw_a)O%j!3>(Bev{X$nCgN<8XT&)a_xIi;L9;EFjSdO{v|FoIbW7@vX@k#MiWRJ zug;Gk6>^gmrOh;2K}@`gKY>($6=%1Bn0OU`0jU7FsR)!vqd3YEbK-OIi;F{w`ao4R zJBXiHnykrGGy$Y?A}IeemnG&D^?<nG)D6n#MV*riMMNgci>lNUE_(PG85oMdc3Of8 z4N-8>Q#1)=;A9X1E_lEt3mb?FFJ5ke+cx0R-XA0}9YmlPP~cQQ1tc>SL<E8ekTY+w z6(#1S<mZCgE~zP+T;LoOB>?gts3wTdFD)r3Eh#PnN8=2Tu9+ZW7KoS)B0z?KGjY)z z5EmS(U;><pL|~a{GM_Au<U3Ff0yUl(7$ulQ7)6*w82MPG7!_D1o6Gv47IVueD&}ff zQkZ&~z-<en3%Ytxk(R<venF?n1g^Cq<qbHzz!f66jsu4hC@DdLY91)4=7WfZAOaj- zhF}sLYMdph#U<dHSd5W@p%_}`fNElJnFDF=aWGE~RnX-JHAp#NO*S{i$rBZPCjV4W zt!KAks9^&YBQ6ZFS}{zu?6n*v95w8-8B#cEIA$}<Wd`N+8jb~=3mF!0Eo7)+t6{fc zD6B%{`B=YNQ16qshEtLug%i>u%YyK@BpGVCYq(OlvzdzSr7+a;fXfu_X2u#WX@**! z5{?v}=mq>WoHaZ(+%;U#LS;h=FOEV*poR(5afFs@f(wLd*wPsnGS>3e@GcOpVaO6m z5s+qB$OtW01VQDBkzNtu;z1jhp{fig=V;1K_L35r{8HJNh_XTltP@nQgc2?*^uUTi z@d%FKc2EJq0gl1M(&Q>;1zQDZ69lA$5k!zuDtN(6zr_MF6I{SlaY5R6ddbCQn%K(; zc<E3S04keULG@#C5vTyT#hhGFdW$u$G`FC#z6iaX067m_LV!w<TTBHhw>VuAOA_6R z5_3~;v6bY<Cl{9$IfBY3K1fl)2Pq-qAucK|LM<b1F(;NL7iEK-z>=L>S$vDFEHS4v zwYUgUb|izW0~eH_!s8Yj#1Qm`5g*9o@$g2G7AQ|>fx0V<&@zNsfKh;jgI$18h*5}% zj}Zhp7}=P37{wUHm?o!bMmm57b&5cpu_93GL6ZsWc8CBt+kpvi-V;gA&r8cpFD*)q z&o3y+%+0JyEh-M2{9n^>vZnSF#@fmIv@O6{Rg>`+S8`%OW=UcWNDZzAWh^-UCaddI z2!O1<#gYdKNtUeq%)H5Kbvz6~jr*cqpa9}ZPRvcsi3iKGf_rOd0a0Yjz`zi`SxvW( z(Fc^&{4_alv1jHL7o;Ya+~P`4Edg~i3sRGdK;6SzEFfOdPEZJgEH2szV(kT4#8z66 zl30?e$qosH!;|z4M8E+6CP1N3bdZ68A#<{een35_FHpm<fT4zQ0pmi38m5IzwJaq} zS<E#o3s@F1EMQ&8Py*`Cu{Sfq=oH3Y=2})zmu(>fBSQ&S7IzIRXs9fOrI&dD&q4-B zYX#KV%i>$WU&B0`A%$%&Qw_)zfrX%!4s!~7Hd9eY3C9A#8d-)cAyAn$mkCq=lyEE% zu3-bQz@rMBleG-=CdV3xyA`bjwI4Y2lXK%iiA7VS$QR^YaLaKmh_wzxtOgNlKm<J5 z-(oJvDY?Z4iSfyM4QzBk9=^q1oLH8cmYEJt>bF>vbMlMf;f1IGdnd~oO6h<@4NQQ8 z4wP}qK}l2qoJ0i}RTw!KMVRH7I6%GfBttbu{>f7fRqH{KRMZG^%v=z$2t+Ig5i3B% z77(!wM1Ue493Mp>lW*}t#{-g6b8=u~0$|JYK#IXG5N2Rt0QKIALA4GCGfyFhI)?;@ zB!>ov2!}cc8;1}FUy<<S03#2^*vT`E%q+KpvNLm$703XPt8TFtXCxM+!g9%dkT}?M zusvWaK+4(}C$kxEss~AHGJ})SE!NDug3^*(%=x7yXleHrds==`d16sY7bs{z*^z;b zhp~z$EHkxS4=Sa}R&*O=*BuaX7es)pDS8ZI-2)L|`@sa*IbfIg#WFB3w89MKVB}#b z0*PoE-(oJVEG`1Y_AREof}$fJ1NMQ4rQkeJkeHW}SX=~398v7WsfkJXMR`S+874QF zDApHU1eM|3zLoy45}*j=;wU!XN}v4V;v!II>=t)vNoG#59+)fwwaB74OEPnF^+3(U zB6g5PplZ2@6T||I(}SCHw>V1ji;^?+^74uzKz0bjxO!=&dC4W2`FV*sw<hbGa&z5c zbSnZ4!xV8(PBGQ52aU|!;s&cO%Ph{!&jXJ{MF}B^=z+(bii`L`{^KbrO3W)x%P-1J zEh;VoCB-7pNLCbga$-SAX|7&kNl8(W5J)AzPiA^X38?Syo1apelUf8yKDXGQ98g#m zNif8KhU-8+0f#;);BRr*<bu`PfvU7(3kC)T@W=q9Kg7tv$iv9P%mX4Bc^E~QL3|;Q TD3}Ej7hvLH5~^pYXGjGA2Ue+& diff --git a/models/model_interface.py b/models/model_interface.py index 60b5cc7..b3a561e 100755 --- a/models/model_interface.py +++ b/models/model_interface.py @@ -7,6 +7,8 @@ import pandas as pd import seaborn as sns from pathlib import Path from matplotlib import pyplot as plt +import cv2 +from PIL import Image #----> from MyOptimizer import create_optimizer @@ -20,7 +22,10 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchmetrics +from torchmetrics.functional import stat_scores from torch import optim as optim +# from sklearn.metrics import roc_curve, auc, roc_curve_score + #----> import pytorch_lightning as pl @@ -29,6 +34,10 @@ from torchvision import models from torchvision.models import resnet from transformers import AutoFeatureExtractor, ViTModel +from pytorch_grad_cam import GradCAM, EigenGradCAM +from pytorch_grad_cam.utils.image import show_cam_on_image +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget + from captum.attr import LayerGradCam class ModelInterface(pl.LightningModule): @@ -41,11 +50,20 @@ class ModelInterface(pl.LightningModule): self.loss = create_loss(loss) # self.asl = AsymmetricLossSingleLabel() # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1) - # self.loss = + # print(self.model) + + + # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) self.optimizer = optimizer self.n_classes = model.n_classes - self.log_path = kargs['log'] + print(self.n_classes) + self.save_path = kargs['log'] + if Path(self.save_path).parts[3] == 'tcmr': + temp = list(Path(self.save_path).parts) + # print(temp) + temp[3] = 'tcmr_viral' + self.save_path = '/'.join(temp) #---->acc self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] @@ -53,6 +71,7 @@ class ModelInterface(pl.LightningModule): #---->Metrics if self.n_classes > 2: self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted') + metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes, average='micro'), torchmetrics.CohenKappa(num_classes = self.n_classes), @@ -67,6 +86,7 @@ class ModelInterface(pl.LightningModule): else : self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted') + metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2, average = 'micro'), torchmetrics.CohenKappa(num_classes = 2), @@ -76,6 +96,8 @@ class ModelInterface(pl.LightningModule): num_classes = 2), torchmetrics.Precision(average = 'macro', num_classes = 2)]) + self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes) + # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10) self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes) self.valid_metrics = metrics.clone(prefix = 'val_') self.test_metrics = metrics.clone(prefix = 'test_') @@ -146,28 +168,40 @@ class ModelInterface(pl.LightningModule): nn.Linear(1024, self.out_features), nn.ReLU(), ) + # print(self.model_ft[0].features[-1]) + # print(self.model_ft) + if model.name == 'TransMIL': + target_layers = [self.model.layer2.norm] # 32x32 + # target_layers = [self.model_ft[0].features[-1]] # 32x32 + self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform + # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform + else: + target_layers = [self.model.attention_weights] + self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True) + + def forward(self, x): + + feats = self.model_ft(x).unsqueeze(0) + return self.model(feats) + + def step(self, input): + + input = input.squeeze(0).float() + logits = self(input) + + Y_hat = torch.argmax(logits, dim=1) + Y_prob = F.softmax(logits, dim=1) + + return logits, Y_prob, Y_hat def training_step(self, batch, batch_idx): #---->inference - data, label, _ = batch + + input, label, _= batch label = label.float() - data = data.squeeze(0).float() - # print(data) - # print(data.shape) - if self.backbone == 'dino': - features = self.model_ft(**data) - features = features.last_hidden_state - else: - features = self.model_ft(data) - features = features.unsqueeze(0) - # print(features.shape) - # features = features.squeeze() - results_dict = self.model(data=features) - # results_dict = self.model(data=data, label=label) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] + + logits, Y_prob, Y_hat = self.step(input) #---->loss loss = self.loss(logits, label) @@ -183,6 +217,14 @@ class ModelInterface(pl.LightningModule): # Y = int(label[0]) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (int(Y_hat) == Y) + self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True) + + if self.current_epoch % 10 == 0: + + grid = torchvision.utils.make_grid(images) + # log input images + # self.loggers[0].experiment.add_figure(f'{stage}/input', , self.current_epoch) + return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} @@ -212,18 +254,10 @@ class ModelInterface(pl.LightningModule): def validation_step(self, batch, batch_idx): - data, label, _ = batch - + input, label, _ = batch label = label.float() - data = data.squeeze(0).float() - features = self.model_ft(data) - features = features.unsqueeze(0) - - results_dict = self.model(data=features) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] - + + logits, Y_prob, Y_hat = self.step(input) #---->acc log # Y = int(label[0][1]) @@ -237,18 +271,23 @@ class ModelInterface(pl.LightningModule): def validation_epoch_end(self, val_step_outputs): logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0) - # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0) probs = torch.cat([x['Y_prob'] for x in val_step_outputs]) max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs]) - # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0) target = torch.cat([x['label'] for x in val_step_outputs]) target = torch.argmax(target, dim=1) #----> # logits = logits.long() # target = target.squeeze().long() # logits = logits.squeeze(0) + if len(target.unique()) != 1: + self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) + else: + self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True) + + + self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True) - self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) + # print(max_probs.squeeze(0).shape) # print(target.shape) @@ -276,24 +315,61 @@ class ModelInterface(pl.LightningModule): random.seed(self.count*50) def test_step(self, batch, batch_idx): - - data, label, _ = batch + torch.set_grad_enabled(True) + data, label, name = batch label = label.float() + # logits, Y_prob, Y_hat = self.step(data) + # print(data.shape) data = data.squeeze(0).float() - features = self.model_ft(data) - features = features.unsqueeze(0) + logits = self(data).detach() + + Y = torch.argmax(label) + Y_hat = torch.argmax(logits, dim=1) + Y_prob = F.softmax(logits, dim = 1) + + #----> Get Topk tiles + + target = [ClassifierOutputTarget(Y)] + + data_ft = self.model_ft(data).unsqueeze(0).float() + # data_ft = self.model_ft(data).unsqueeze(0).float() + # print(data_ft.shape) + # print(target) + grayscale_cam = self.cam(input_tensor=data_ft, targets=target) + # grayscale_ecam = self.ecam(input_tensor=data_ft, targets=target) + + # print(grayscale_cam) - results_dict = self.model(data=features, label=label) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] + summed = torch.mean(torch.Tensor(grayscale_cam), dim=2) + print(summed) + print(summed.shape) + topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0) + topk_data = data[topk_indices].detach() + + # target_ft = + # grayscale_cam_ft = self.cam_ft(input_tensor=data, ) + # for i in range(data.shape[0]): + + # vis_img = data[i, :, :, :].cpu().numpy() + # vis_img = np.transpose(vis_img, (1,2,0)) + # print(vis_img.shape) + # cam_img = grayscale_cam.squeeze(0) + # cam_img = self.reshape_transform(grayscale_cam) + + # print(cam_img.shape) + + # visualization = show_cam_on_image(vis_img, cam_img, use_rgb=True) + # visualization = ((visualization/visualization.max())*255.0).astype(np.uint8) + # print(visualization) + # cv2.imwrite(f'{test_path}/{Y}/{name}/gradcam.jpg', cam_img) #---->acc log Y = torch.argmax(label) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name, 'topk_data': topk_data} # + # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data def test_epoch_end(self, output_results): probs = torch.cat([x['Y_prob'] for x in output_results]) @@ -301,7 +377,8 @@ class ModelInterface(pl.LightningModule): # target = torch.stack([x['label'] for x in output_results], dim = 0) target = torch.cat([x['label'] for x in output_results]) target = torch.argmax(target, dim=1) - + patients = [x['name'] for x in output_results] + topk_tiles = [x['topk_data'] for x in output_results] #----> auc = self.AUROC(probs, target.squeeze()) metrics = self.test_metrics(max_probs.squeeze() , target) @@ -312,9 +389,41 @@ class ModelInterface(pl.LightningModule): # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True) - # print(max_probs.squeeze(0).shape) - # print(target.shape) - # self.log_dict(metrics, logger = True) + #---->get highest scoring patients for each class + test_path = Path(self.save_path) / 'most_predictive' + topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0) + for n in range(self.n_classes): + print('class: ', n) + topk_patients = [patients[i[n]] for i in topk_indices] + topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices] + for x, p, t in zip(topk, topk_patients, topk_patient_tiles): + print(p, x[n]) + patient = p[0] + outpath = test_path / str(n) / patient + outpath.mkdir(parents=True, exist_ok=True) + for i in range(len(t)): + tile = t[i] + tile = tile.cpu().numpy().transpose(1,2,0) + tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255 + tile = tile.astype(np.uint8) + img = Image.fromarray(tile) + + img.save(f'{test_path}/{n}/{patient}/{i}_gradcam.jpg') + + + + #----->visualize top predictive tiles + + + + + # img = img.squeeze(0).cpu().numpy() + # img = np.transpose(img, (1,2,0)) + # # print(img) + # # print(grayscale_cam.shape) + # visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True) + + for keys, values in metrics.items(): print(f'{keys} = {values}') metrics[keys] = values.cpu().numpy() @@ -329,16 +438,35 @@ class ModelInterface(pl.LightningModule): print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count)) self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] + #---->plot auroc curve + # stats = stat_scores(probs, target, reduce='macro', num_classes=self.n_classes) + # fpr = {} + # tpr = {} + # for n in self.n_classes: + + # fpr, tpr, thresh = roc_curve(target.cpu().numpy(), probs.cpu().numpy()) + #[tp, fp, tn, fn, tp+fn] + + self.log_confusion_matrix(probs, target, stage='test') #----> result = pd.DataFrame([metrics]) - result.to_csv(self.log_path / 'result.csv') + result.to_csv(Path(self.save_path) / f'test_result.csv', mode='a', header=not Path(self.save_path).exists()) + + # with open(f'{self.save_path}/test_metrics.txt', 'a') as f: + + # f.write([metrics]) def configure_optimizers(self): # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1) optimizer = create_optimizer(self.optimizer, self.model) return optimizer + def reshape_transform(self, tensor, h=32, w=32): + result = tensor[:, 1:, :].reshape(tensor.size(0), h, w, tensor.size(2)) + result = result.transpose(2,3).transpose(1,2) + # print(result.shape) + return result def load_model(self): name = self.hparams.model.name @@ -372,18 +500,33 @@ class ModelInterface(pl.LightningModule): args1.update(other_args) return Model(**args1) + def log_image(self, tensor, stage, name): + + tile = tile.cpu().numpy().transpose(1,2,0) + tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255 + tile = tile.astype(np.uint8) + img = Image.fromarray(tile) + self.loggers[0].experiment.add_figure(f'{stage}/{name}', img, self.current_epoch) + + def log_confusion_matrix(self, max_probs, target, stage): confmat = self.confusion_matrix(max_probs.squeeze(), target) + print(confmat) df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) - plt.figure() + # plt.figure() fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() # plt.close(fig_) - # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}') - self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch) + # plt.savefig(f'{self.save_path}/cm_e{self.current_epoch}') + - if stage == 'test': - plt.savefig(f'{self.log_path}/cm_test') - plt.close(fig_) + if stage == 'train': + # print(self.save_path) + # plt.savefig(f'{self.save_path}/cm_test') + + self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch) + else: + fig_.savefig(f'{self.save_path}/cm_test.png', dpi=400) + # plt.close(fig_) # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) class View(nn.Module): diff --git a/test_visualize.py b/test_visualize.py new file mode 100644 index 0000000..7ef56d3 --- /dev/null +++ b/test_visualize.py @@ -0,0 +1,148 @@ +import argparse +from pathlib import Path +import numpy as np +import glob + +from sklearn.model_selection import KFold + +from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule +from models.model_interface import ModelInterface +import models.vision_transformer as vits +from utils.utils import * + +# pytorch_lightning +import pytorch_lightning as pl +from pytorch_lightning import Trainer +import torch +from train_loop import KFoldLoop + +#--->Setting parameters +def make_parse(): + parser = argparse.ArgumentParser() + parser.add_argument('--stage', default='train', type=str) + parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str) + parser.add_argument('--version', default=0,type=int) + parser.add_argument('--epoch', default='0',type=str) + parser.add_argument('--gpus', default = 2, type=int) + parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str) + parser.add_argument('--fold', default = 0) + parser.add_argument('--bag_size', default = 1024, type=int) + + args = parser.parse_args() + return args + +#---->main +def main(cfg): + + torch.set_num_threads(16) + + #---->Initialize seed + pl.seed_everything(cfg.General.seed) + + #---->load loggers + # cfg.load_loggers = load_loggers(cfg) + + # print(cfg.load_loggers) + # save_path = Path(cfg.load_loggers[0].log_dir) + + #---->load callbacks + # cfg.callbacks = load_callbacks(cfg, save_path) + + home = Path.cwd().parts[1] + DataInterface_dict = { + 'data_root': cfg.Data.data_dir, + 'label_path': cfg.Data.label_file, + 'batch_size': cfg.Data.train_dataloader.batch_size, + 'num_workers': cfg.Data.train_dataloader.num_workers, + 'n_classes': cfg.Model.n_classes, + 'backbone': cfg.Model.backbone, + 'bag_size': cfg.Data.bag_size, + } + + dm = MILDataModule(**DataInterface_dict) + + + #---->Define Model + ModelInterface_dict = {'model': cfg.Model, + 'loss': cfg.Loss, + 'optimizer': cfg.Optimizer, + 'data': cfg.Data, + 'log': cfg.log_path, + 'backbone': cfg.Model.backbone, + } + model = ModelInterface(**ModelInterface_dict) + + #---->Instantiate Trainer + trainer = Trainer( + num_sanity_val_steps=0, + # logger=cfg.load_loggers, + # callbacks=cfg.callbacks, + max_epochs= cfg.General.epochs, + min_epochs = 200, + gpus=cfg.General.gpus, + # gpus = [0,2], + # strategy='ddp', + amp_backend='native', + # amp_level=cfg.General.amp_level, + precision=cfg.General.precision, + accumulate_grad_batches=cfg.General.grad_acc, + # fast_dev_run = True, + + # deterministic=True, + check_val_every_n_epoch=10, + ) + + #---->train or test + log_path = cfg.log_path + # print(log_path) + # log_path = Path('lightning_logs/2/checkpoints') + model_paths = list(log_path.glob('*.ckpt')) + + if cfg.epoch == 'last': + model_paths = [str(model_path) for model_path in model_paths if f'last' in str(model_path)] + else: + model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)] + + # model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)] + # model_paths = [f'lightning_logs/0/.ckpt'] + # model_paths = [f'{log_path}/last.ckpt'] + if not model_paths: + print('No Checkpoints vailable!') + for path in model_paths: + # with open(f'{log_path}/test_metrics.txt', 'w') as f: + # f.write(str(path) + '\n') + print(path) + new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg) + trainer.test(model=new_model, datamodule=dm) + + # Top 5 scoring patches for patient + # GradCam + + +if __name__ == '__main__': + + args = make_parse() + cfg = read_yaml(args.config) + + #---->update + cfg.config = args.config + cfg.General.gpus = [args.gpus] + cfg.General.server = args.stage + cfg.Data.fold = args.fold + cfg.Loss.base_loss = args.loss + cfg.Data.bag_size = args.bag_size + cfg.version = args.version + cfg.epoch = args.epoch + + log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent) + Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True) + log_name = f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}' + task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:]) + # task = Path(cfg.config).name[:-5].split('_')[2:][0] + cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' / 'checkpoints' + + + + #---->main + main(cfg) + \ No newline at end of file diff --git a/train.py b/train.py index 036d5ed..5e30394 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ import glob from sklearn.model_selection import KFold -from datasets.data_interface import DataInterface, MILDataModule +from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule from models.model_interface import ModelInterface import models.vision_transformer as vits from utils.utils import * @@ -13,18 +13,23 @@ from utils.utils import * # pytorch_lightning import pytorch_lightning as pl from pytorch_lightning import Trainer -from pytorch_lightning.loops import KFoldLoop import torch +from train_loop import KFoldLoop #--->Setting parameters def make_parse(): parser = argparse.ArgumentParser() parser.add_argument('--stage', default='train', type=str) parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str) + parser.add_argument('--version', default=2,type=int) parser.add_argument('--gpus', default = 2, type=int) parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str) parser.add_argument('--fold', default = 0) parser.add_argument('--bag_size', default = 1024, type=int) + parser.add_argument('--resume_training', action='store_true') + # parser.add_argument('--ckpt_path', default = , type=str) + + args = parser.parse_args() return args @@ -39,9 +44,10 @@ def main(cfg): #---->load loggers cfg.load_loggers = load_loggers(cfg) # print(cfg.load_loggers) + save_path = Path(cfg.load_loggers[0].log_dir) #---->load callbacks - cfg.callbacks = load_callbacks(cfg) + cfg.callbacks = load_callbacks(cfg, save_path) #---->Define Data # DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size, @@ -58,11 +64,12 @@ def main(cfg): 'batch_size': cfg.Data.train_dataloader.batch_size, 'num_workers': cfg.Data.train_dataloader.num_workers, 'n_classes': cfg.Model.n_classes, - 'backbone': cfg.Model.backbone, 'bag_size': cfg.Data.bag_size, } - dm = MILDataModule(**DataInterface_dict) + if cfg.Data.cross_val: + dm = CrossVal_MILDataModule(**DataInterface_dict) + else: dm = MILDataModule(**DataInterface_dict) #---->Define Model @@ -82,9 +89,9 @@ def main(cfg): callbacks=cfg.callbacks, max_epochs= cfg.General.epochs, min_epochs = 200, - # gpus=cfg.General.gpus, - gpus = [2,3], - strategy='ddp', + gpus=cfg.General.gpus, + # gpus = [0,2], + # strategy='ddp', amp_backend='native', # amp_level=cfg.General.amp_level, precision=cfg.General.precision, @@ -96,12 +103,31 @@ def main(cfg): ) #---->train or test + if cfg.resume_training: + last_ckpt = log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' / 'last.ckpt' + trainer.fit(model = model, datamodule = dm, ckpt_path=last_ckpt) + if cfg.General.server == 'train': - trainer.fit_loop = KFoldLoop(3, trainer.fit_loop, ) - trainer.fit(model = model, datamodule = dm) + + # k-fold cross validation loop + if cfg.Data.cross_val: + internal_fit_loop = trainer.fit_loop + trainer.fit_loop = KFoldLoop(cfg.Data.nfold, export_path = cfg.log_path, **ModelInterface_dict) + trainer.fit_loop.connect(internal_fit_loop) + trainer.fit(model, dm) + else: + trainer.fit(model = model, datamodule = dm) else: - model_paths = list(cfg.log_path.glob('*.ckpt')) + log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' + + test_path = Path(log_path) / 'test' + for n in range(cfg.Model.n_classes): + n_output_path = test_path / str(n) + n_output_path.mkdir(parents=True, exist_ok=True) + # print(cfg.log_path) + model_paths = list(log_path.glob('*.ckpt')) model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)] + # model_paths = [f'{log_path}/epoch=279-val_loss=0.4009.ckpt'] for path in model_paths: print(path) new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg) @@ -120,6 +146,16 @@ if __name__ == '__main__': cfg.Data.fold = args.fold cfg.Loss.base_loss = args.loss cfg.Data.bag_size = args.bag_size + cfg.version = args.version + + log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent) + Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True) + log_name = f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}' + task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:]) + # task = Path(cfg.config).name[:-5].split('_')[2:][0] + cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name + + #---->main diff --git a/train_loop.py b/train_loop.py new file mode 100644 index 0000000..f923681 --- /dev/null +++ b/train_loop.py @@ -0,0 +1,212 @@ +from pytorch_lightning import LightningModule +import torch +import torch.nn.functional as F +from torchmetrics.classification.accuracy import Accuracy +import os.path as osp +from abc import ABC, abstractmethod +from copy import deepcopy +from pytorch_lightning import LightningModule +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.trainer.states import TrainerFn +from datasets.data_interface import BaseKFoldDataModule +from typing import Any, Dict, List, Optional, Type +import torchmetrics +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + + + +class EnsembleVotingModel(LightningModule): + def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str], n_classes, log_path) -> None: + super().__init__() + # Create `num_folds` models with their associated fold weights + self.n_classes = n_classes + self.log_path = log_path + self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] + self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths]) + self.test_acc = Accuracy() + if self.n_classes > 2: + self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted') + metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes, + average='micro'), + torchmetrics.CohenKappa(num_classes = self.n_classes), + torchmetrics.F1Score(num_classes = self.n_classes, + average = 'macro'), + torchmetrics.Recall(average = 'macro', + num_classes = self.n_classes), + torchmetrics.Precision(average = 'macro', + num_classes = self.n_classes), + torchmetrics.Specificity(average = 'macro', + num_classes = self.n_classes)]) + + else : + self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted') + metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2, + average = 'micro'), + torchmetrics.CohenKappa(num_classes = 2), + torchmetrics.F1Score(num_classes = 2, + average = 'macro'), + torchmetrics.Recall(average = 'macro', + num_classes = 2), + torchmetrics.Precision(average = 'macro', + num_classes = 2)]) + self.test_metrics = metrics.clone(prefix = 'test_') + self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes) + + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + # Compute the averaged predictions over the `num_folds` models. + # print(batch[0].shape) + input, label, _ = batch + label = label.float() + input = input.squeeze(0).float() + + + logits = torch.stack([m(input) for m in self.models]).mean(0) + Y_hat = torch.argmax(logits, dim=1) + Y_prob = F.softmax(logits, dim = 1) + # #---->acc log + Y = torch.argmax(label) + self.data[Y]["count"] += 1 + self.data[Y]["correct"] += (Y_hat.item() == Y) + + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} + + def test_epoch_end(self, output_results): + probs = torch.cat([x['Y_prob'] for x in output_results]) + max_probs = torch.stack([x['Y_hat'] for x in output_results]) + # target = torch.stack([x['label'] for x in output_results], dim = 0) + target = torch.cat([x['label'] for x in output_results]) + target = torch.argmax(target, dim=1) + + #----> + auc = self.AUROC(probs, target.squeeze()) + metrics = self.test_metrics(max_probs.squeeze() , target) + + + # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1)) + metrics['test_auc'] = auc + + # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True) + + # print(max_probs.squeeze(0).shape) + # print(target.shape) + # self.log_dict(metrics, logger = True) + for keys, values in metrics.items(): + print(f'{keys} = {values}') + metrics[keys] = values.cpu().numpy() + #---->acc log + for c in range(self.n_classes): + count = self.data[c]["count"] + correct = self.data[c]["correct"] + if count == 0: + acc = None + else: + acc = float(correct) / count + print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count)) + self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] + + self.log_confusion_matrix(probs, target, stage='test') + #----> + result = pd.DataFrame([metrics]) + result.to_csv(self.log_path / 'result.csv') + + + def log_confusion_matrix(self, max_probs, target, stage): + confmat = self.confusion_matrix(max_probs.squeeze(), target) + df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) + plt.figure() + fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() + # plt.close(fig_) + # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}') + self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch) + + if stage == 'test': + plt.savefig(f'{self.log_path}/cm_test') + plt.close(fig_) + +class KFoldLoop(Loop): + def __init__(self, num_folds: int, export_path: str, **kargs) -> None: + super().__init__() + self.num_folds = num_folds + self.current_fold: int = 0 + self.export_path = export_path + self.n_classes = kargs["model"].n_classes + self.log_path = kargs["log"] + + @property + def done(self) -> bool: + return self.current_fold >= self.num_folds + + def connect(self, fit_loop: FitLoop) -> None: + self.fit_loop = fit_loop + + def reset(self) -> None: + """Nothing to reset in this loop.""" + + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the + model.""" + assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) + self.trainer.datamodule.setup_folds(self.num_folds) + self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict()) + + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance.""" + print(f"STARTING FOLD {self.current_fold}") + assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) + self.trainer.datamodule.setup_fold_index(self.current_fold) + + def advance(self, *args: Any, **kwargs: Any) -> None: + """Used to the run a fitting and testing on the current hold.""" + self._reset_fitting() # requires to reset the tracking stage. + self.fit_loop.run() + + self._reset_testing() # requires to reset the tracking stage. + self.trainer.test_loop.run() + self.current_fold += 1 # increment fold tracking number. + + def on_advance_end(self) -> None: + """Used to save the weights of the current fold and reset the LightningModule and its optimizers.""" + self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt")) + # restore the original weights + optimizers and schedulers. + self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) + self.trainer.strategy.setup_optimizers(self.trainer) + self.replace(fit_loop=FitLoop) + + def on_run_end(self) -> None: + """Used to compute the performance of the ensemble model on the test set.""" + checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)] + voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths, n_classes=self.n_classes, log_path=self.log_path) + voting_model.trainer = self.trainer + # This requires to connect the new model and move it the right device. + self.trainer.strategy.connect(voting_model) + self.trainer.strategy.model_to_device() + self.trainer.test_loop.run() + + def on_save_checkpoint(self) -> Dict[str, int]: + return {"current_fold": self.current_fold} + + def on_load_checkpoint(self, state_dict: Dict) -> None: + self.current_fold = state_dict["current_fold"] + + def _reset_fitting(self) -> None: + self.trainer.reset_train_dataloader() + self.trainer.reset_val_dataloader() + self.trainer.state.fn = TrainerFn.FITTING + self.trainer.training = True + + def _reset_testing(self) -> None: + self.trainer.reset_test_dataloader() + self.trainer.state.fn = TrainerFn.TESTING + self.trainer.testing = True + + def __getattr__(self, key) -> Any: + # requires to be overridden as attributes of the wrapped loop are being accessed. + if key not in self.__dict__: + return getattr(self.fit_loop, key) + return self.__dict__[key] + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) \ No newline at end of file diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc index 704aade8ba3e8bcbedf237b3fde5db80b84fcd32..33d3d005a578948dd3d5b58938a815f86c4e92f5 100644 GIT binary patch delta 2565 zcmew*wN#!jk(ZZ?fq{Xc{>JR2w;~hyWa=+5GBBhtq%h_%<T6GvGJ@DlIZV0CQOvn4 zQ7m9Sa}H}1YYIaOa}HZBdlUy)j3tLNmn(`3%x2Bu&gF^X1+&?5_;UH9_`z)U9D!WH zC_ylrBS$D#I7&EIBuWI#=gbk!6^jyMtOr@hl_Q=j5hal;86}x36(t2$z?~zWD;*^b zX7l98<jO|Lg4w(|a=G$R@?bV!jzX?tlp>hTpQDtk9HpGA5~Y%>8l{@67NwS}9;FVJ zW6sft(nw)Q5y;U5$ulyfsHF&|2(>UqY1OANqzLDT=W0jkFfycwq=>dKMCqn5q=@C{ z<?2W2=Nd#AfQ=B(G0ZiJG6J(Da*T6LqD&YWQY2HPS{R~C!C?anBePs{CI&_ZXoyIs z$h0s-S)?$e$U-4=irgHw6lRc~6y}tA*%XKXNP&EcLJBh)UlGb@PLWN)F0TYOP<akp zlx2!aifRi(lvN6IFoUN0OHfQ}GTvedNG!?FWV*%d=;VA$I5DZXq$n}DBsnLsxHz{y zwIm}y#ZQy*7OS^geol%e<1MbV(!Au7%>2B>98JbsJVl9lDfzka#RWN;B_LUs#F9jx z{KS;hB2A`S>>yroYRN6O;L@bxRFDD=$K>SFqQvA%P3Bv|KAGtmC3%^7=^%B!`6;D2 zskiuxQ&UsoQ_E6|DoZlzGxO4Kv4<2TX6B_9X)@m8hwFf-207m+Kfgee@fN#VW(kzZ z3Dw}1r^$MY*)gy37K=+}a>*?gpUmQtTO9rc5a-`w38^ed)nqIZVPIfLW&{NX6f-k0 zFmNz1FgQ<+W2|M9Vqjn>5}EvvQHoJ<@*hUGdNBqDhGI65T1Gxb0mdR_1_lO@3v%+) z(^HFzF^mE!bq47vV_;yYVaQ@gVa#SO(kfxhVya=tVoqU7VNPM`Wv*o`VO+oh5?RPt z%amtQ!U|$HGo~;~Gt@A|GsDz7rZ9rFu=cWmv}A#`q%fwi^)l5mm9V9-gS5?Nn9J16 z$jDH`Si_Xg6wIK>;dhIrB(=CC#7&cF@;hdE6Xsj&WvNBQnfZA|Y9Jq~gMCq?!N9;! z#iy&Qt83?zpRSOSS!AW4$$pEqxFj(>b@Fr;1$B;`{Pg&O#FC6#oX){vK44E5F)=VO z++r;#%FHX#WWL3moR&WM4U2@UBFJPgp#)OElarsA5)Tb0Sx}J3fr62fi;07Y<sTah zAF}`>2O|?Bh|R&s^q-AIfVs$EawKbDJ;(x(S`Y>WC=V!pKqhQqU|>jRsAZ}F2UIO{ z4RaPl3KJ-VdYKp*N*GgEQa}O81P*An6!u=$T9z871uQiz3mF+1N?1$SQaDmLQ#eyN zQn<j7yO%wMr<bvo6|9OCq^gFYK8q=fJ%u-$sb~X8xP&8xrG&GFxtTGA4{RBK3R?<O zia;+1SUJ-ImJ*H>L8uBLs0v|Z6~PReB7V2HlQUA2vkUSw^Gb?CG?{L3q*i1Wm&E61 z-(oLFEK1EQDZa%~mY5TtGr5<|no)A{Nw!^*R-lk%PsuC-#{oxLW=?8eVs2`Y?c_P^ z(u|Umcd$DIa@}Ih%}&WIDl!2pc2CVqElSKOvH|hf5(^4a^HPe8KyhOX7Aa0GDoZUY zG6Zp$i%W_$*^2Z*Y|i4uvebBxJ(F!Y3=Ew>I>50BCcsh32aeL@#GIU@#N_N^a5@iT zU|?WiVqxN#JdZ<1or!~yjgbife{yhOiWEgo{>brI3}G&pO-^QUNpgN}fnDX~SDZy6 zASq3zTdXDdMadbrm~-;ei()4iaEa)EB#J<8DM|#TU631-Km;g9YqH&9ElSKwPrb#Q zlbTnQJb43`R9pf`HjROSA&RvElo)TZ<m49@7lF*tWG({Broxnf4ZbCuT$Ep29G{w3 zQj}j%84q?s1t{@>N(lx=DMl4WF~*`?kS4Fm8r*BRL1uvsEK-_$pIfHB1Ed(7gqezJ z!5Pg@(;5`VB^f!HNs!#boS2kc1S)!pSU?6PgB&OZwyG=>luh+YiW2jR)AEaQi*Iou zu^B^di5F+*q$U>S>E-69q~^pIr{<)B%jhCjP%Ltk)PurIFE6hMlyi!rL4pEcll4kV zGINUcQW8rNi`YTJ(&)lqIZzy@78QYfdrJbVisICgB2JLAK@EsoDquCasU=03$;Eof zIrWLf#hGcD$%&wnS1%D-1r%|EY!oc01gnnEfz<{@puAng3o?xlMDT-@$zUkc%gN6# zDAr3#EKV&F04Y(xr6jFBvm_p*x=0X*x)MkgqgPy#Sdv;?Bm`2%R#I7znU`K93~~dg zjwliZu|TDNkvNDY1BxNG#FUiG<PvZtP^1D9)&da*AQPlv4w~G-Bg_uYUPVTet9dNg zz}Ze~as!Xd<f}aVY@l3P6gl}lk4z0D6;*-E03{SiB0@?|Aaglva>13Q9Vm$xyD=~@ y@Gyc3P98=W6kz0G5`pt%m^m1E7`fPZ7&(*#K%z_>#vF`7j2s*SjBxmijS~P`KaFPq delta 1941 zcmZ1~|4WK5k(ZZ?fq{YH?7=BX?LrgzWa^n185mL+QW$d>av7r-89{8O9HuCy6owS$ z9Ohh>C>BPB6v-5p6xJ5TDAp8)6t*0;T=pmqMursj6pj{#D9#jy6wVy3T<$3DT%IT% zunAl_yt#Z)d|)<r4u7sdlmH_`3Qr1e3qzD(3PTE0j!={kn64Mj6=7mvWXKhb5@TdY z;Y;CfVTck>VMq~xLZ%eKIczB)=P;!(r3j=z_#k<q6yX$*7)Tz*7lF!yRAQGG1)C=} zhb>AXMLb19p@ktzGKDFaK~w4_C}1=hZ?Oa<mSp&W7%rK~C6oD>Co8fsFfbG`GcYh{ zGT&lND@ZKKxW$r|nUi{pxiU9rB|{PK<U`DEf*@7JAcBFBk5PcJh<~yqi*!9mfSrMX z!I^=9q1b?tfuV*Wiy@1#ma&AXhOwD3i#dfcg)xPxmx+;~gr$V3h9Qf!nK6s4nK6aA zh9QeRi(@v!Tqbb_Fpo2f3&N{qf~(|CVMt+VWs+p51*zvrVa;YPItMYkhDn;CzJ^(v zA%(4mA&WPQ52RU|p@t!zv4lTMpoX!8rG}}QQJkTM8CgXNBPeQnS!-EJ__G9Ys^aWr zu4T>hLRba1i51QKdYA>Qc#LID0UK5#l)_cRki`gcW-}urLk&v}YdR=yxcx#_GTq`x zt;j4ciO<iz#a@tDl$uvke2X#u<?sLh|KDOMNi8n9#gdkvlj5eya*HJ|F*o%Vds%8x zF>_{q-sBT3^7Twb5};U=WME*Z;?vdD)wT1<Pgh9EEV5G26u!moo|>0hl$djiBPTyS z9u&hxObiSRw^(ztQ!<NgF&CE<-C|45&r8cpzr_Zz{T8cleoAW2Esmtb<m{yUywqDP zKKaGPw>Xm$i&Nus@{5bXHr`?_E=f#J<-Wz@l30>hB+kIVFj<;ap`Oz@ILs$MJw3JP z7HdIKW?qRV+b!ngwDclbkiB4MfP4|f1M)>?eqKD7$61mTp99tv#g$xK24xrVg1iPM z_&~aOa`F>XpmN0^;bJ{d;%4Gt6k^n26k*|FR8VH*W8`CG`Nzh>$H@1egINs3gYsGa zv#|&;7wJvj&l*_I2+|3~tPBhcJg~I9je&t7ouQTyl%8vuYM8P>=@^!hnNpZjSZWxu zn6p??7*kkN*m_xOnQIsqu+}gyWMpJ0VJl$=C9)Kb6!sKOFy!iGPvP!mtYrbKVgacF z>5t9gNa4w5D%u1RF5ygJE&(MGhGxbTUU0_YO93S;{$38S2F3-fC7dY&a5aKZH9{$j zDR4Ex44T4zAw`0qm|$mjVq~4{#lB0@6cljmDVar}$iKw_%HE(5FS3{{%OTCkI@yxL z!Ic{vIYn}y&^7`QG9bbnq?0YNpdd9brN{szVh9!~PAw`+Eh^FniO7Qr57r_bke#PE z3=C~Sf}ogzL>)NR_`tE2oS2gXN>RnepaLTd6!J_gOl*_2Id#+-IT$$@nGo<N2M4A| zQNZLn&c|Y4!@=fq+2mvvmn7%s7T9G>F6Sx=0ZC~x-C`}tFG|k1#hjC$UK9dyAJZ+) ziumIEw36J!id&482&WZAFfcG=f}9))B0xz(lkFC3QDR<t>MiD+)V!jo$+p~5#^E5@ zSOx}$DAo#4^1H<X&IVxnn2W%&F)$@XV7CcRZsZb{NC$b91LRdkDMl4WF~*`4kTTcF zTe;V8gLFfkGC7Aw#=IP)7?e2~n2Jh3&Svz}lmZnAB^f!HN#JZ-1Trg%Eiok}Gr0s@ za2N4Up35uFEdVlB14QUeF6UEZ1ILZQ<avA+jE0l1^BHi1V>JMzAZju%zfANkuC&s; z<dV$%yu_TMAdooNWrZLX$a09okempLc@CT0{FKt1R69^W6@#qfVdh}uVdP@tVdPK} QVB}%sVd5~FJds}n0E2$Q-2eap diff --git a/utils/utils.py b/utils/utils.py index 010eaab..814cf1d 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -10,10 +10,11 @@ from torch.utils.data.dataset import Dataset, Subset from torchmetrics.classification.accuracy import Accuracy from pytorch_lightning import LightningDataModule, seed_everything, Trainer -from pytorch_lightning.core.module import LightningModule +from pytorch_lightning import LightningModule from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.trainer.states import TrainerFn +from typing import Any, Dict, List, Optional, Type #---->read yaml import yaml @@ -27,30 +28,30 @@ def read_yaml(fpath=None): from pytorch_lightning import loggers as pl_loggers def load_loggers(cfg): - log_path = cfg.General.log_path - Path(log_path).mkdir(exist_ok=True, parents=True) - log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}' - version_name = Path(cfg.config).name[:-5] + # log_path = cfg.General.log_path + # Path(log_path).mkdir(exist_ok=True, parents=True) + # log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}' + # version_name = Path(cfg.config).name[:-5] #---->TensorBoard if cfg.stage != 'test': - cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}' - tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name), - name = version_name, version = f'fold{cfg.Data.fold}', + + tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path, + # version = f'fold{cfg.Data.fold}' log_graph = True, default_hp_metric = False) #---->CSV - csv_logger = pl_loggers.CSVLogger(log_path+str(log_name), - name = version_name, version = f'fold{cfg.Data.fold}', ) + csv_logger = pl_loggers.CSVLogger(cfg.log_path, + ) # version = f'fold{cfg.Data.fold}', else: - cfg.log_path = Path(log_path) / log_name / version_name / f'test' - tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name), - name = version_name, version = f'test', + cfg.log_path = Path(cfg.log_path) / f'test' + tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path, + version = f'test', log_graph = True, default_hp_metric = False) #---->CSV - csv_logger = pl_loggers.CSVLogger(log_path+str(log_name), - name = version_name, version = f'test', ) - + csv_logger = pl_loggers.CSVLogger(cfg.log_path, + version = f'test', ) + print(f'---->Log dir: {cfg.log_path}') @@ -63,11 +64,11 @@ from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme from pytorch_lightning.callbacks.early_stopping import EarlyStopping -def load_callbacks(cfg): +def load_callbacks(cfg, save_path): Mycallbacks = [] # Make output path - output_path = cfg.log_path + output_path = save_path / 'checkpoints' output_path.mkdir(exist_ok=True, parents=True) early_stop_callback = EarlyStopping( @@ -94,8 +95,9 @@ def load_callbacks(cfg): Mycallbacks.append(progress_bar) if cfg.General.server == 'train' : + # save_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.resume_version}' / last.ckpt Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss', - dirpath = str(cfg.log_path), + dirpath = str(output_path), filename = '{epoch:02d}-{val_loss:.4f}', verbose = True, save_last = True, @@ -103,7 +105,7 @@ def load_callbacks(cfg): mode = 'min', save_weights_only = True)) Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc', - dirpath = str(cfg.log_path), + dirpath = str(output_path), filename = '{epoch:02d}-{val_auc:.4f}', verbose = True, save_last = True, @@ -136,87 +138,3 @@ def convert_labels_for_task(task, label): return label_map[task][label] -#-----> KFOLD LOOP - -class KFoldLoop(Loop): - def __init__(self, num_folds: int, export_path: str) -> None: - super().__init__() - self.num_folds = num_folds - self.current_fold: int = 0 - self.export_path = export_path - - @property - def done(self) -> bool: - return self.current_fold >= self.num_folds - - def connect(self, fit_loop: FitLoop) -> None: - self.fit_loop = fit_loop - - def reset(self) -> None: - """Nothing to reset in this loop.""" - - def on_run_start(self, *args: Any, **kwargs: Any) -> None: - """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the - model.""" - assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) - self.trainer.datamodule.setup_folds(self.num_folds) - self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict()) - - def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance.""" - print(f"STARTING FOLD {self.current_fold}") - assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) - self.trainer.datamodule.setup_fold_index(self.current_fold) - - def advance(self, *args: Any, **kwargs: Any) -> None: - """Used to the run a fitting and testing on the current hold.""" - self._reset_fitting() # requires to reset the tracking stage. - self.fit_loop.run() - - self._reset_testing() # requires to reset the tracking stage. - self.trainer.test_loop.run() - self.current_fold += 1 # increment fold tracking number. - - def on_advance_end(self) -> None: - """Used to save the weights of the current fold and reset the LightningModule and its optimizers.""" - self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt")) - # restore the original weights + optimizers and schedulers. - self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) - self.trainer.strategy.setup_optimizers(self.trainer) - self.replace(fit_loop=FitLoop) - - def on_run_end(self) -> None: - """Used to compute the performance of the ensemble model on the test set.""" - checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)] - voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) - voting_model.trainer = self.trainer - # This requires to connect the new model and move it the right device. - self.trainer.strategy.connect(voting_model) - self.trainer.strategy.model_to_device() - self.trainer.test_loop.run() - - def on_save_checkpoint(self) -> Dict[str, int]: - return {"current_fold": self.current_fold} - - def on_load_checkpoint(self, state_dict: Dict) -> None: - self.current_fold = state_dict["current_fold"] - - def _reset_fitting(self) -> None: - self.trainer.reset_train_dataloader() - self.trainer.reset_val_dataloader() - self.trainer.state.fn = TrainerFn.FITTING - self.trainer.training = True - - def _reset_testing(self) -> None: - self.trainer.reset_test_dataloader() - self.trainer.state.fn = TrainerFn.TESTING - self.trainer.testing = True - - def __getattr__(self, key) -> Any: - # requires to be overridden as attributes of the wrapped loop are being accessed. - if key not in self.__dict__: - return getattr(self.fit_loop, key) - return self.__dict__[key] - - def __setstate__(self, state: Dict[str, Any]) -> None: - self.__dict__.update(state) \ No newline at end of file -- GitLab