From 54cf0ab371c16bfc04e88dd53d40476ad8fa3bee Mon Sep 17 00:00:00 2001 From: Ycblue <yuchialan@gmail.com> Date: Fri, 22 Apr 2022 10:31:17 +0200 Subject: [PATCH] First Commit --- .gitignore | 1 + Camelyon/TransMIL.yaml | 4 + DeepGraft/TransMIL.yaml | 50 +++ DeepGraft/TransMIL_dino.yaml | 50 +++ DeepGraft/TransMIL_resnet18_all.yaml | 48 +++ DeepGraft/TransMIL_resnet18_no_other.yaml | 48 +++ DeepGraft/TransMIL_resnet18_no_viral.yaml | 48 +++ DeepGraft/TransMIL_resnet18_tcmr_viral.yaml | 48 +++ DeepGraft/TransMIL_resnet50.yaml | 50 +++ DeepGraft/TransMIL_resnet50_all.yaml | 50 +++ DeepGraft/TransMIL_simple.yaml | 50 +++ .../ND_Crossentropy.cpython-39.pyc | Bin 0 -> 5580 bytes MyLoss/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 981 bytes .../__pycache__/boundary_loss.cpython-39.pyc | Bin 0 -> 9477 bytes MyLoss/__pycache__/dice_loss.cpython-39.pyc | Bin 0 -> 15245 bytes MyLoss/__pycache__/focal_loss.cpython-39.pyc | Bin 0 -> 2868 bytes MyLoss/__pycache__/hausdorff.cpython-39.pyc | Bin 0 -> 4122 bytes .../__pycache__/loss_factory.cpython-39.pyc | Bin 0 -> 2411 bytes MyLoss/__pycache__/lovasz_loss.cpython-39.pyc | Bin 0 -> 2237 bytes MyLoss/loss_factory.py | 3 +- .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 632 bytes .../__pycache__/adafactor.cpython-39.pyc | Bin 0 -> 5478 bytes .../__pycache__/adahessian.cpython-39.pyc | Bin 0 -> 5843 bytes MyOptimizer/__pycache__/adamp.cpython-39.pyc | Bin 0 -> 3204 bytes MyOptimizer/__pycache__/adamw.cpython-39.pyc | Bin 0 -> 3827 bytes .../__pycache__/lookahead.cpython-39.pyc | Bin 0 -> 3134 bytes MyOptimizer/__pycache__/nadam.cpython-39.pyc | Bin 0 -> 3030 bytes .../__pycache__/novograd.cpython-39.pyc | Bin 0 -> 2259 bytes .../__pycache__/nvnovograd.cpython-39.pyc | Bin 0 -> 3708 bytes .../__pycache__/optim_factory.cpython-39.pyc | Bin 0 -> 3377 bytes MyOptimizer/__pycache__/radam.cpython-39.pyc | Bin 0 -> 4051 bytes .../__pycache__/rmsprop_tf.cpython-39.pyc | Bin 0 -> 4619 bytes MyOptimizer/__pycache__/sgdp.cpython-39.pyc | Bin 0 -> 2915 bytes MyOptimizer/lookahead.py | 4 +- MyOptimizer/optim_factory.py | 3 +- datasets/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 193 bytes .../__pycache__/camel_data.cpython-39.pyc | Bin 0 -> 1852 bytes .../camel_dataloader.cpython-39.pyc | Bin 0 -> 2882 bytes .../custom_dataloader.cpython-39.pyc | Bin 0 -> 7729 bytes .../__pycache__/data_interface.cpython-39.pyc | Bin 0 -> 5690 bytes datasets/camel_dataloader.py | 126 +++++++ datasets/custom_dataloader.py | 332 ++++++++++++++++++ datasets/data_interface.py | 81 ++++- models/TransMIL.py | 16 +- models/__init__.py | 2 +- models/__pycache__/TransMIL.cpython-39.pyc | Bin 0 -> 3333 bytes models/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 193 bytes .../model_interface.cpython-39.pyc | Bin 0 -> 10240 bytes .../vision_transformer.cpython-39.pyc | Bin 0 -> 11353 bytes models/model_interface.py | 213 +++++++++-- models/vision_transformer.py | 330 +++++++++++++++++ train.py | 50 ++- utils/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 138 bytes utils/__pycache__/utils.cpython-39.pyc | Bin 0 -> 2832 bytes utils/utils.py | 26 +- 55 files changed, 1569 insertions(+), 64 deletions(-) create mode 100644 .gitignore create mode 100644 DeepGraft/TransMIL.yaml create mode 100644 DeepGraft/TransMIL_dino.yaml create mode 100644 DeepGraft/TransMIL_resnet18_all.yaml create mode 100644 DeepGraft/TransMIL_resnet18_no_other.yaml create mode 100644 DeepGraft/TransMIL_resnet18_no_viral.yaml create mode 100644 DeepGraft/TransMIL_resnet18_tcmr_viral.yaml create mode 100644 DeepGraft/TransMIL_resnet50.yaml create mode 100644 DeepGraft/TransMIL_resnet50_all.yaml create mode 100644 DeepGraft/TransMIL_simple.yaml create mode 100644 MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc create mode 100644 MyLoss/__pycache__/__init__.cpython-39.pyc create mode 100644 MyLoss/__pycache__/boundary_loss.cpython-39.pyc create mode 100644 MyLoss/__pycache__/dice_loss.cpython-39.pyc create mode 100644 MyLoss/__pycache__/focal_loss.cpython-39.pyc create mode 100644 MyLoss/__pycache__/hausdorff.cpython-39.pyc create mode 100644 MyLoss/__pycache__/loss_factory.cpython-39.pyc create mode 100644 MyLoss/__pycache__/lovasz_loss.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/__init__.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/adafactor.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/adahessian.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/adamp.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/adamw.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/lookahead.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/nadam.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/novograd.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/optim_factory.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/radam.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc create mode 100644 MyOptimizer/__pycache__/sgdp.cpython-39.pyc create mode 100644 datasets/__pycache__/__init__.cpython-39.pyc create mode 100644 datasets/__pycache__/camel_data.cpython-39.pyc create mode 100644 datasets/__pycache__/camel_dataloader.cpython-39.pyc create mode 100644 datasets/__pycache__/custom_dataloader.cpython-39.pyc create mode 100644 datasets/__pycache__/data_interface.cpython-39.pyc create mode 100644 datasets/camel_dataloader.py create mode 100644 datasets/custom_dataloader.py create mode 100644 models/__pycache__/TransMIL.cpython-39.pyc create mode 100644 models/__pycache__/__init__.cpython-39.pyc create mode 100644 models/__pycache__/model_interface.cpython-39.pyc create mode 100644 models/__pycache__/vision_transformer.cpython-39.pyc create mode 100644 models/vision_transformer.py create mode 100644 utils/__pycache__/__init__.cpython-39.pyc create mode 100644 utils/__pycache__/utils.cpython-39.pyc diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4cf8dd1 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +logs/* \ No newline at end of file diff --git a/Camelyon/TransMIL.yaml b/Camelyon/TransMIL.yaml index 6642b70..fba514c 100644 --- a/Camelyon/TransMIL.yaml +++ b/Camelyon/TransMIL.yaml @@ -29,9 +29,13 @@ Data: batch_size: 1 num_workers: 8 + + + Model: name: TransMIL n_classes: 2 + backbone: resnet18 Optimizer: diff --git a/DeepGraft/TransMIL.yaml b/DeepGraft/TransMIL.yaml new file mode 100644 index 0000000..7945828 --- /dev/null +++ b/DeepGraft/TransMIL.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 6 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_dino.yaml b/DeepGraft/TransMIL_dino.yaml new file mode 100644 index 0000000..ffe987c --- /dev/null +++ b/DeepGraft/TransMIL_dino.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 1000 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: dino + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_all.yaml b/DeepGraft/TransMIL_resnet18_all.yaml new file mode 100644 index 0000000..8fa5818 --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_all.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_all.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 6 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_no_other.yaml b/DeepGraft/TransMIL_resnet18_no_other.yaml new file mode 100644 index 0000000..95a9bd6 --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_no_other.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [4] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 5 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_no_viral.yaml b/DeepGraft/TransMIL_resnet18_no_viral.yaml new file mode 100644 index 0000000..155b676 --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_no_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 4 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml new file mode 100644 index 0000000..e7d9bf0 --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 2 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet50.yaml b/DeepGraft/TransMIL_resnet50.yaml new file mode 100644 index 0000000..d76b2cf --- /dev/null +++ b/DeepGraft/TransMIL_resnet50.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 32 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: resnet50 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet50_all.yaml b/DeepGraft/TransMIL_resnet50_all.yaml new file mode 100644 index 0000000..eba3a4f --- /dev/null +++ b/DeepGraft/TransMIL_resnet50_all.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 1000 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_all.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 6 + backbone: resnet50 + + +Optimizer: + opt: lookahead_radam + lr: 0.0001 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_simple.yaml b/DeepGraft/TransMIL_simple.yaml new file mode 100644 index 0000000..4501f2c --- /dev/null +++ b/DeepGraft/TransMIL_simple.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc b/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b39de626d6dab873f33584529f0c3a4bef6bf961 GIT binary patch literal 5580 zcmYe~<>g{vU|={fWp?5<83u;OAPzESVPIfzU|?V<UdO<|kirnfkiwY4l*1Us2&S2G znVA?E8FE>oSQ$Yw%sFhi>{0AsHd78q6sJ2w3QG!W3quNPDq|LN7FRQK6t_D=3R?<$ z3quNfDt9w;6puSY3P%cO3quMgTn(=~Lkd?4cMC%bH&~3XnK_C-g(ro#g&|5Hg)fD_ zg&|7NogqaaMX-e-MG&q=C`BkmxP>7~IE5*gK~tovgUh)nzqmLxucRoypwiDpAu%sS zAtb-R+b6%cSRt_}RUxe?Ki5hjqokyu*h*hNJ+maEG)XTxKUd$^(^=QW+bv2zFE7+D zwM0KDCqGF)H?g=RwMaiNuQV^UM8BjcF*7eSFI_(;zqmL)tu!yWBr`v+Sl`bj9@$#G zf=aHJpfL5*WV|Jol384mn3tRyUs9BqSDcn#lpCL#Qj!dkN5;%hPO$<514Al96k`fQ z6jKUg6mvU68e<Am3Udob6iYh;3qurZFoPz`Ek0x${ak+OaVaP$DEOt8l;;;^D`X@V zE0p9bWF(fQD){*;`1$+!c>1|I26=`=DuNXz!|VfD%f`UK0E%~zSaAgd149kd0)~YQ zwM;dP3m6wNq%cY{)H0_q*0Q89)w0&GmN3;YW-&K2W-}Cd)G#h!sbQ&MUC30+Si_RS zXwFc}Qo{<SSqo)q7#FZEWT<7VVO+qL!d$})R>cHjFJ!D`Dq%0-s9|VkTEJPuSi`gs z6oV{&FBusa7+x}gi2wip|JUTY#gdbsmwt<-II}AC7IRK&-YwSRjKqS}Tb$qkEyyoU zy~UNBpI4HZUYcK8e2b+lGqwB{YjJ5oYEco$N4MD1@{7t7i&8Y%Zm|@n=A_+X&de*g z#a5D7l%860i#xA0H$FKhvA8(3_!d*9!7Zjt!<CG;IO8FL@$oAee%a_}<maXam*f{2 z>R09@=IMh1xws%PIaNOd6sx|TKKj0upfrXhEffhbFfj0gvXlfUuo*cRd6-xjS(x}3 zc^EkuxtKtFE~Y9m)Cko>IGQ6qJ})shH9r0pSA2YKeoAQ$h|LopUs#%$1C?Qqk59=@ zj*kahc8eG4_~g`_ocQ=6Nd^W62q6w)Nr4DZ$Z&%k1&VkMMiyppSYaefkSxg6#d>JT zGKCYIFj-UBS~x)oGmSBYJ%yu%BZ?KANZEoJG`VhZz_MPIHaOMcNS<hE)9>Y+|NsBL z1eH_CF#ADn0I@+y73B0{B~UVDu3?C0C}AvNs$pnm%3`izOku2H0u@xfEQ|~(48aVV z%zkc~jJMc|Qd3HkQ#Dy`F=iKmqM`_t%WiSR$7kkcmc++vvJ?q2FfiO=D^JZ#&nUUY zlbN2EUz8f3nU|7UQKSM2FL20#0}>p3stgPaYOvsAU}R&g5`_6x4`wPx9%E%-U;sNl z3FLUD6owRrR;CojG$u&~NPdHPWdSoJN7OQxuq<G$Va{S($XLr#!k)!Z!@Phqg=ryU z30DnE4MP^U2tx|9G($5ZsMPIcj$x{0u4S!ZUcj>u<asu~B2aG9WGrF_<);(|1_p3w z6z3O}q^4-{7IA=l1`iKN+F{BoxW!UXl%G-rN*=dZOY)17Gj6ez<QHTY@qq$~IWw;W z945C|a#IuYG}*vmS0n^7OBk%6D783>3zV(O@+(qvif=K-7oi0oC`)L8f)5nQ42&X- zLX13&8jMu}@IZtme`dcTMUcs$_+H6gBn@KAfCyO-AqOHraRGLFkvxbCjwCPvwiFy) zpy&pfPz*{xpw!OC0!{+B(>us@$i;99J2<^Fr?7z1dkQGsGa#jPj$0Dp;IxyP0?x;- zkQxY7Bou+ttzR;zNCTx}5C*Y92?2zQL8U?oLkS}&4KH9?$iT=@!&t)*&kRn#jDDId zkQ8eS@*CKzpme9n1POHz8$EDAip)TP3$mDjk&m%T5trMrr&UmD0qJ!H8PWhwtB~|b z>-r#zBa5?!A&V=8rI)dmt%ePh3z|Xo0;^vUDDxD7q7R&yH5rRQUf1M?q+w1_;DFK+ zIMEjIfVjLM0+eElbU>++DKGDq07l$^4Pl3bB|HTeiGZ|<f(S7N28Jk(^5oQbP@NAc zM8RnboUA~}*a{R#ARYtulATe2sY(H7_`{MpqhFB`$Ud;yU;>eNK`{@KL{7XQF^t3u zN><>u8z@PkBwkSafgy?|g(a9lll7JqF6V?9Rs91O)VLHVXoNevYAPfaE2N|rCl_TV zrKTuk=E0j&iA5EeWqSEV>G}mJY5Iod2IhJO#zw~GdIc$IU~`G90Fv_yDseZ|%JXyb zD)kF0!3C49aSFI;25zmHrJ5!hnIxH+7+a)SB%2u<8d+E-8yc7yTO=Bqr=%ttnxvQ* z8JHUCgR~f@fSYlt#d-ym%07lhhF}Mxl-VfxpoxKjp@boeaRIo3NoQKfSj$|)w164f zPGVWekism<P|K3ST+5om0!qCltTl{TY~XfML=EEt_8Qh2HgF4wHHF!np_a9V4NS8Y z>eMhU;DEN1I8#__Siq{77I3Dpg4;=43%Ef!0o+1j^GlyK69i7$Lm4ma(enlmB*n8N zm!>4%;&#i)Pb>*Z%`46?A|kWgVsyL3n+$41$3t7unvm>qiz_d+BtE~iq@c9q7E^jj z6fd~Z6c4E~i_!87www|Q&naR8puED2r4^)vE9He5!CFLqD_OBt1E9=^WGy%of%0V$ zs4VmY*$pbwI2c*jz#_?@su5HegD^-2RJMRiJUdW{$FP7Qg>fNcEn^B3s8&v4u3-Sh zbu&{lqd2H#%&>r^29gGt!EIdT8s-ITHH;}N;8<d<n(3C4SW=RjSFDhzke*qVnx_Eb z<QFLvmnLT@lw>59D3oNRDkNtl=H;d4C?r;77As^HE2L!>6_-HT@*x?i3gDJ5q~ii^ z?Si@wsVSiRmzJNClV4tJ1x~dJ3JMy2IttD@3NAVd9y$u)nhLsh3L4Hj3VsS&3N8v- z3LXkt3gMby#o*MU$pogr?E|DzmAwd5lHOu1DJ{rJy~SFTn3tY<i#a7T_ZE9WYEf=! zNoo<iSk+{?#Rl=yEk@@kE|3r6^NUhai*9j&-3n@%ryxZLC@yZXr$K_H6BHky0-J%A zi%E<L)Cw+gWnf@Hi4{=(gU8AQaQ<7sSi`W8v6i8PVF6POLl!eAhDty!A!bl7rJ2!% zAvUd+sfMA33Djq)VdiJ3Whr5;VQyxuWi4S?z*fV$kg<kUf}w_4grSzTgtdmHnX#4) zEW@^tv4%~8p@v0-p_Z)%u3DO*1}-nnAP%aWnQRyeRcaW688n%yJi%e=9-;@1Dg|2w zjikhq<cxSwcST2`LZ?!v3KH!3c{!B|Y57G8Ntt<xMU@K0sp+|?c_oRUE+p6ta9n7z z7lB$Hx7dpdOH)&;Qg5+Vr55EEL#j1<Pz?bp<Zm(OCRS*&-(qqPDKY?6Zj70?*bDND za}$fRZ?Wg4rb8G7`Ng2VCRhZ-K#vUB<ow)%(vnn=TycC_eo;VbUSdvWRcZ>T-7*~% zMWFhRft8Dqhf#=;gOQDq1567sDlr!MgQ5$Rd5|&4b;MLbd@h;AaNC?+tAfEn;ExtY z#xAJA1XhHj4k6IT1vO?t&0TN~mtX{y<)FT8zXl^i2@|9oUjQn@8M9fM8B>_V8ETos z8B!Qn7+4sZnHd@K1PmF9%?ubA!7!465tIU1Ye4Cb8Pq-pmGs%nMOih>3)n&Ig^UY0 zN*J;@!L<bQ0<IK@N~Q(e(E5UBAwvoaxV~ViWldoP*B87sj9Gl(`eIKF;{tx9`ho>o zUw~+~!U;8u3k0C`1y>3)q`oL&$P&yFN?}Q1>t(EEuVJrYY-X(Gs9}QI%~ZotB3#3i z#h1b^$<WML!vyMea5L0$)^LLJ0#^;^0ugYpfx|By5~C07p^Ojq;5HYe!~kVTO~xWn zBtufHD5SPw&&!WbFG@@Sr;A(6$pxjiSo2DA3o37MfyP|oA#Apk)UwRv)LTp?`L|d~ zDhpCUB|?5;iLnv5b3vejG6AJ^E}#7L;QX|b+{B6^aE{PtU|`VXMr*^}VgU`jpw&uI zY&rSq@x{4O9GQ6q@j3bF#h}rJ;^HV?NG+C{g0q>o9+VqE4M_%OE=DsZP>v9RcHYsd zF-AT{4kj)p2__DvDlyb#4zI$BY(ZHWTty)WPytyKz`($;0%RAcN&;6_AQ6<_1E|*k zk_9JaP|v1@p@gA^F^dsW1lE8%l-!_7iV4)802fo+z*#4`w74Wc7nH#j@^W$%Kuub( zcyWGdQF1C`SS&X+C9^bF4>T<1P?eFGmmXi4n3t|!nwOGV1RBjs$;?aF1ve6Op^XGx zusNBz1v#nkQcX8APZ!oqD~3dvZccu>E{vU*lY?f6iJ@7dQIe5?L8`F{*g!}MV}`^t zC<_;XBb)gaYi3?SX-N^NzgZNG8kV47yv321lLL;+%doIyV3S}hiUlPVks@%n5fquX zgu$&yy}Ud<*l=EAP7x%TAqiN5(j9Mca%MrLUS3LOZen_B5vZgq3I<sXDzJ-8K`c<r zfcuq2p!WGKHec{y4Y;`huHC?S102O*mmrnL;8Xx{xea89#}3p!0e22rSU{yN4<iS& Sh`4}&h>VCe2Oo!^h&BL`v=DLt literal 0 HcmV?d00001 diff --git a/MyLoss/__pycache__/__init__.cpython-39.pyc b/MyLoss/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..119cc4d4a02cb0c100c171e5e23385ccb8309724 GIT binary patch literal 981 zcmYe~<>g{vU|<N_J14P_iGkrUh=Yuo7#J8F7#J9e*Dx?Jq%fo~<}l<kMlmvi*i1Q0 zxy(__xhzpExvWvFxolBvx$IHwxg1d(V0q>o&Rni2u3YXY?p&TI9<VyzDBfJYD85|& zDE?f5D1ltTD8XEzD4|^8DB)a@C=swemK@Psu_!Sxn>9x~S0YLxS29Wx%xBAy%9W0i z2D907WO8MrWOL=B<Z|Vs<Z~6G6mk`#6d4&(ID#28IbSj|FfeE`-r`O!N=+<DjnBz1 zF4knb#pdMVlV4nXizhfgt;8iWITg&~b8(JO%u9)PasjKm#pmJT65^DZmsnH@mKTA_ zd$>RZA*zI2GK)(f0x**_d2VsK!#KCN;mk0jTWrC>AhU{C7#J9Cv3ur+f_S&ML&{Q% zinAfs2)gAbC+0v!Aa-y#7FXtiRl?lp><W<)2uRIK%qgja+i^<(>dKP*f^3KyZr6$e zpZs(%N0aduC)j+DX_`#8_&pL!i&OH8($ZW)Ktf0Yu0bFHO~zZiKKW&d#Z{oN%1x}$ zWWB}bT$Ep2oSIislwVNk=W>f9B)`BLr1+LZcxq;PMoDT4ScNM@1z1K3mrR(^Ek2jb zVz56lt5Q>(UHvqbqIf_t6`z)vT#{c@Sp-UfQM^g{rFkidMU`NW6frX}Fo1)xh!w<Q z0}<>Xf`fs9A&N64GZ_-yMVuf(Pzo>N2C;ZR1TTo-0}=cnLI6Yvf(RiHAq<Kft~5|I zK<p9$8<_zMsv=R4Dlw2CcMdpgAj-r+BK&?X@nGLW!>33Bq*xL}NP!4x5Fx|Bz_5~` zNDjmX5x)%dGxBp&^(%7{^YqK}i?WLg5|dN)Ly8jfihVtO^nJleLq9$~GcU6wK3=b& h@)n0pZhlH>PO2R!aEd{`;$c)^vS8+66kz0G1OVBt93%h$ literal 0 HcmV?d00001 diff --git a/MyLoss/__pycache__/boundary_loss.cpython-39.pyc b/MyLoss/__pycache__/boundary_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34040291007bdff88adfcea8434647cf1d47f3a2 GIT binary patch literal 9477 zcmYe~<>g{vU|={fWp-k{Is?OF5C<8vFfcGUFfcF_UtwTiNMT4}%wdRPNMTB0&SA`D ziek!Tj$&p6iL>M|<g!GufZ41$thsDaY+yEH4to?w3R?<$3qurVifjsJ3Reqb6jzE| z3U>-m3u6?wJ3|U@3SSFD3STNq7Ed#C6t6o&3V(_~3qy(kRD{o+Aw@7nsD&X#2r9y# zBAg=9!Vo3k&X6LSBG$r?A_kQabZ1BrPmyS0NRfbw2&G7-NVPCT3A;0-NT<lOFr>&p zWkgaKgBdjCUxNJW_mYu;fkBh$7N2ubesOVXUP)1YL8YI|Esl`<0&k!E;$lt4TTFR* znvA!^QZkE667!N%<4cMX^NQ2*i*n;rQ%aIS27rtLVNkHKGB7YWgMuy%6m*O!Of3v0 z3?+;;3=5bRGBh(bGrBM|GuDFnDa;EQYZz0QZ5V18vY1ntQ&>`1dzn($BpI?;Y8d+% zYMDw{7qHbZ)G$dfEM%-@u3@fW&SHnDgQ$^a2xib^_d{_869WSSGXn!dFvuAK3=9nE z3^fd~9JP!o3=^3OS%Mif8H+$_R)T!aq-Su8F}{eMfq?-`{Ib%|$j?m;F3B%4)UV7* z%+oK=FUl@1NK8)E4*~hy*V9Mew-OXK`bqhvc`1oSmGL?G#l?CBmAANTax#lclJj#5 z?6?^i7(RmnrAnzdKdmG;u_8VrHK!o8NY5rGKRGd{*iH|j80?&om5jI8vr|(GQZjRk zK*6fXaf`JmF)uy!7IRK&-Yu5m%&OE|%(;mbx7dnO3sMtHZn2i+7bR!hVot3nxW!yt znyblji?QMsb5TJ_6l+C%Zem3gb837A*k2F=l<aTup*n~Y6igrq1|}&+K1K;fKE@(R z1_lPVWKbZ2ya2)|aT~_Kz)-@lfU$;gA>#t366S@V7-cFls$q0th!v}4u3^YxDPgW* zZe~njOkwI}s%0pY$Wtg`tpV|wdzoAqVue7e7C==kWUOT<w5tKdC5vB35iclFGTve@ zN-aw*Do)j8yTz1OaEq-pFSD>T^%h%VaY<!C>MiEXypkd@1_p*(Y{mISC8;TzEFiIh zTP%qcsl`QnAQ`4a4BvAVm*&Qoq~;ap7YTxV4pPU!B*4hR$iXPaB)}-ZSftFrz>o|| z$RLk{FeoxPpsA|(1|tJQ4O0qZDMOJQC~2oKl`z#XEMNvFZ>DC(S|*V9Ygua;vsh}` zKuLHZL;oB`h7zV4)@H^MrW&?p#%$)Ilp4kb>@{pD%#sW#Ea^-&Y$+@$jIB(P42u|p zK*^o4gd>GDg{_65gcD>Q$h2lAQ08f7ENZJ^tYy!$s$pEfwUD8fv4nL2cMa=8##;6o zhAf^e-W2w1rlL(X><jq7e15PD`vQR!mKw$smW9lW4B-qZj4TW+49(1p40!^E48>Or z7#P7Yl7W$-hM|UCgrSzBh9QNehFyX|grSx*g{6i;0wT`U!H~jO$~b|ssI7!GOR$8Y zogs}eg(HQtg`<QuOQ?n;g(Zax6v?3M4q?}VWMC?};VL;HvJhWE*tMKUX7a#QazSKK zRC3j@OkgZpQp1tLBn)ybCz$4};R4fKCBiivHB6v#wwJk<qeQrd6VBr-5w78a^SEj_ zYB+1S(wKu8H2D+NnHU(j6ciK`lJg5H71E0GbFCCIN=gcft@QQNGfOf`lk}4FbM<{a zopoKj-J<mK@<RPmOZ1a+@{{y)6N^hyi}dsIO7l`n^h=5oGxIX@()B^bLws6kUUErh zeqOPDN@j8@tjGY{l$Tl(pI=&1P+FppTUuPAkd&&Rk))%LtfNq&qfn`#qfn)(sR>q@ zj!==Pkdv5{nxl}LSfHR$oRL@n(PW6MNx?rz0U?ECm_mM$LVjMVLPmaxLTX-eeoAIu zI;NdqD>Ks+aubWQ6*7wz3X1Z}GE-7h6f&XC$Ve<pMX~_JZis`e3KEMFa}^Liv;rFn z6;3Zf;ezrliU7!BD~R>5fH71^%P&$WOU%hkQ7Fi)NX;o$NX$!7FaQf@<{?WU+f-aw znpl)-rI4AX5K>f{s!)=Vnx~LfprcSy0M?(Dr%;}mlLHERh|&~=q|~(hqEv<A(%jrc zP{cuuD@rXXEy{zq((fgx8UdHp;1V3A76KKXph6Q|3rWCgA%PgCT2KkWSi_Xf1gVW6 zB{Qf3TFG1lDrIglX5M1VLI@Xu3g;qF8DtEqrFa+^7^)P~Q%mAY3gXiW;?wf5RVlYv zl1oz(QCtLaET~iinNqC8zyPWhrD_>#7&}018HQG-6vi|rNd^&yFa}13LZ%dkV1|`U zeh}}0oyP=rUKDdheCA5VTP&d1fw&Unz+0R)>8W|C6$M3hPH<<+<8UTSr6%Jo=9J7_ zN}XAx39=F7#v&~cOB+OhxgeJn=|bBiMWCRggIz8l>p^z8L8>y5TdXCe1v#lj{2&QX ziv(Qt++qdi%_0#{O~sy<AD>>7m~xAyEHkzI7IR`w&MoGu%z|4iIr(|%w^*xEi}H(a zv8AMzWhSTIVgc2FMH(QDteJT!sTH?4ijxydN>YpBZ?P7aBqnErs>`IzycBR{rpbMa z3y}qHF{PIjDT17$12TXOlHftEDTuF#VlB=nh_AfGQ5m0~mzt4Za*L^?;1*LF2<Jg! z8yvczfW5_o9KAB2+SCeEUGjozQZ`05CN4$+Mg=ArCN?GxMlMDkMm|OzCJsg(Mjj?E zMlBFzVq@e0K_)&%K1M#KB3Dq&3o-&3gK`pxU#!Bwz>vxi#hAhn#RRT%K+Q{zC~(uU zg&~Tioq>fRiZz%)ll2yxlMASch7#GJHWkP$5H1!1H(@{x(`F`6)0B|`+%9ISIs(q; zkP0BTptK}aAt}EU)XGx;72|Np{Jfk>1yD;5q@pAv6)dZep9U(2Qqzk-Ekn54qSQ33 zb%>8=n5Uy_NU(l!X;E5Ya;h%K21rd}Xkd|)VrgM%l$es7Xku=bYL=XmW@2PwXk?O< zY-VI;VQye<Vrh{Iu3J(;jZbACa}&6Yewxg;Sc^*wQj3b*Kyl0wAD@|*SrQ))%D2U- zIcX~yZ*j)QC+8#<7so?ezn~Z`@_@CQL4^eaTNOVf9Q9!OlR-rqC^A48R9Jxu6mWY# z4jiC0jNsNisGY~y%9O$cYSN@Ir!h-1EMP2Q$^x|%S!$S57{TJ$WLZm?7O<qS)UedB zE@ZA{t6^Kfx{#rky@nkWtyTNrK13vED+P_D#FFHUcu@OTM*-|7P$gKYQw6DmKyg`; zSd^YxVx<6y8HJLf(vl1vNY#a=A6$u{=mj_PtrQB1Qlaskq5$g*fofl<F~qpS4_r;? zRS6{}S5#K%fJt3DFr~=`Zgdw}fbu1zL`ciYPb@Juy2X~7nO9tzdy6GEH8D?<qsR-S z#1<5-Y!GkVVg)<#7E?j;EvA&@Tb#M2IVG6|IjJeP*ppJgO$|uSg%^pp*wgZh$`gxH zVnCS|R8%praWKg-$}#aVaxhi#L1P`NXeHY%j`;Yz#N5>Q_*-1@@wxdar8yurPkek~ zX<`mU2Ap<^KwX1dywIeToSKsZOI~2Ry+KX~Cp1v93<6mWN>Ch(Je=T!nhZ)-$QYEo zK>T7*SqIDeDV*Tk4{9BAMlq*wrLd=Pv@jy&f9_j6!TD(=E}6-xpi;okE#x`K5C(=1 z_Q`Az7MKLFL8So*7aM>RFmj1d!&<`-59)R@#Iw{e)iA`f)-cvE#Ix0~)G)-eL(2;; zE_j;s)8r^}2l)^bs*vEY2SpQWN`8D&`YkR<G6gmNZt)}*6y#LK=jG?+WaiysE6&Z& zFUinkgT!qD$c%Up0m@WGi69oppdxVafddj8hM>%v2MR-Q1hVomu`xnml_;|Jv1eY8 z2_XN2Gw%^lYk)C}p_ZkFA&ap{Erl_iv4p9FIfbc&rG}-Mv6-ovQJkTcHBSg#w3fAo zA&WJIp_H-6r-oqxE2v9Y%UHrz!&t*q!&(FC3bHq|#xU2i)w0)eq%hU6iGaHM9O4X6 zHoG`O4Tm^XOq!t<)Z5`K-UfEZq8iQ=<`f1{uc($YPo#!3g#pBah=JNFwVXAa;64y% z4JS9G`v_^TfT{{`X4B*@0!3^QC^&DifV$^MwJ<oKL3K_MIC#MU4GwaDP!E&~5*y;6 zC}ga>#R)2%<H3y{(AW^91_d{DKtWgpuJl2TJ+!O_&Joq1SdsyyIYu@nHbyBXAto+H zE+z>k6>yyKF^Pb(od`;-K{KBwtDjp)Q7ouKuF2?E6b?!Z;3NPhz$pS0l|`U5mkY86 zR4Q|TI&=T|c)_9=nF^HUKyE2kL#j|y*ua^JDTTQOl&M%!SW&7}_FH@|&hd$PDe+D& zkP-#d^+u@<K*oYFI5*i~<|bxPZYp6(W5{HvVOql22g*yVC2VO7DU2oTnT$1zP%#cz z%M%nmnw*d<l?w7VD8b+2OiWKNN=;8JNd*@fx0sSrz}|$EpP(4fWC6$AEdkI#M?5IM z#b=i%7Nr+QaVDienMLU!1Hr)u4oGkqg3O-+3quA*K1Mc10mdqERF7lNP9Q@-g%v0* z6oXnpH4F<FK$&166DU)G#!70KYM2)?gEI_M3S%}?kshd)6KAMp%i{oLwPwa9#u~O@ zhLy~I;Bk#A_PoRlh2s1?O_m~%w~8`AP6oOAmat!biDzy>PHJvyUP)?-Yf({tktRFX z*F~V3p(p^PG!T?>*;A4s@f8Z<u@$8jm*$jUOQW+u9tW2s;FcOFeR42XiJ^HOmMR&+ zPQ;xYLGd*iWFja<a4_<*6@f%h+GC*L0?C4d3{)S0hm;tym_WlxC7_anDTPszp&2x& zqy`#QsbP|201X53Gt{z_uz-e_Af=NyIG}A93Z-fogBie`0ZrB-X;2WDfe3H`0B(DO z3Or5LTRg?7>7bAU_v>yk7o`>#fwICa_JaIkPzJxno|l>qVxWc7Eiq^d05tSj9G{k7 zl#v1|6_$WJ3JW?RMh-?cMlr^sB5E~e`8-@)LYy-55{tm&jwpdhT%)!eUPnPapx}{M zTAY$!l$Hi*xPm&>Nnjfkki!P94K@OjSX7Z&rk7uou3wOnrf+CzV4`PeU|?)eW~f(? zk_KvNfiPT~pC&UT`xb-Z0hGOpJV7ih`EDI348g5V)+%vS4<m9MsAvI|Tp$eU5QEA( zaPoNrE^aVdpQtTQX3(%5xaG-$-X;aLIu|nZTQD+|u+}gyU|YzrfPEoD3Trk?Gh;1l z32O}tniw0#5Ly$;5E@Gjvjjs5n=q*D$_k=u*g&K>LoG)M(*lkfj)jc195ozSoHguO zoY~Apmuff`aDmth8R0EukV@_p_8RsYj)lzNs*YzNLoHVgS1^Mn$FBl-DC2Ak8&S0# z3~jxFD_CfI>VbU`sHUyrPXf1RD=MpW?LbUTUPx&SZ_O6@f}-6JM1ZozE#~Bc(p#)~ zrMaL{bI>4LJeXZ%4l1OWO7e^RKuW>YFgJR!25L4I*)cFMM6nkn$Ag;UQS9j@5C&I4 zQEEzjDyTfX#R8Hl3IUnOo`JO)46d8Eff63LZf50Rlwy=(<YE+IG-Bpq6kr1NSw$dC zVllKN2`vY47l5D?zZPUPsEFfW<l!g+iBPF77J-(^9xjk#0aC(~Un8e5marn#z-%S# zD77@WMlJ%)=)kJu8c_6r^E0SELaTqlKE1`1kpiuVi!wox#hF2{?!5#GFHoL_)V(5% zRkGMTioHAlmF=Lg2A2n*gojiEGuMD;D{7b^wQ&}t{?(~ru3-T;YZ3J?YcNAm7O3?H zt#fNZ9tHa`8zc&Cp(E<p9FRPyc~=B#e4*7B;PiC^<SAI?z{3cwS7oqz44SGm3DvHk z;J*km7vy1Z?Fte>sTe>F5>SwUgQE>xJT73Y!Ct$jVXs|lSxZ1Ide()EHLMa0HOwLm zwXBdDx0Veo!?uvIhE0N@hDC&-maPV^TAHB-E-%f1T03hpRe6F7F82^U@bH&{Eo#rX zLI>XY29L4etC_$>ktQ3?s=f#`%A?7Ci^)BtC?6D2pjy7D0K_T;B}q_y530(eK=nO{ zffj$aWTDk1NDf@z2c+gD=44i-rhxjDk3ca1>Z>!b@-RZGeGV`!#H7SnRFAFt2gMht z@+t<6Q^NbRj8V)f?BM1%a|%lfXA}!$<|m33Jfg)0D(-IyxnvfXKyo{*Ut6^X9Fiaj z(9A4oFhZdK=0F9IDsbfvR^*#sl3A8mlA5BBREZdj$}dRGD@p~8IO>6BEYk9eN^|Wq zQ*50KQ$3T*f=x`rz?$JU_$C%8fPAh14u_ITg^<*uT+oo0LU3wsVqQsRvO+LY(+boU z0Aa8p&p&27?%4X{iCr?dmIs9b0|Nty4XV{axEM4{4w_P{VSo+FfjivHpbmj1qn{=V zq%qL|ia>DT1M0D9GC?W`5F1+$`6norz~Row$5<r|@(XG(VlPfW#SjRC^Uo<p(1a(X zc~Qa$8ZYc$#>h~@44#!^EMcx;%3=Y}>Uq^LEno%lKrPo4rgSDSn<<4E)Q5%6>@k&q z=K5G$7)sbd`avq-GkZBTOts8;HZ@EOI2JP0GL<kc-~_1wPk%FJab<C*uw^q9&8lHu zz*7U_^MYlV7x1Mp)i9+nLFe|EQRepcpfpUFQkX%mNnuQ32M^0}moR35W(ZlJ{?cG% zC=sY(Ss=KOVS&&>h7=A^6DWlfR7{35q##TdGGv&*SS*2Tb_wGGVemX8YYl4_7s!3t z%o7-k{AySih=6zt85h881C6AB$`iH{#s#9FSx~T>m>_PdVHIIW;abR4%a+0o9WM|Q zX8@~Ygt{K85~K<&4jKby1`jxbeFYK+k12?WGf0D)VC)OTA;XB^Q39T#5>OHXwVlC( z1b!$bCnza`(hI211x}LbpqUc9GdD%5AoD;)E~Eqj_a?y&R7M5{@U%mf9+rd$E$*?+ z-+%}A^@>VC=7Ku^n#^FAf}2|}K`sQdG$kPIEKng-1a4h{XBWU79wU%e&|FWE35aD1 zBEa1tP#X&}&IlS*E3yPhfO@QuHX5kdDk=ksm4k>15K##tK&=bNSQNNfhBOWdZj2Q< zf=kr&lK7I;+=8MikQVUh2q+jKp#lycaL|B83IBrfC1?hMft7;^)aFuw&J^%~`=2^Y zpq7{g6BiR7qX-iRqa2e569-e36z*WuWGrd~)zR!l5g^MWK|~a|^%ozXlAjzO4;hjJ z*<BP3k_5X4Oo00%V3&c)7SOtlVo=@5!N|hO$5I3m(`0nhWcKrO(-eiw*YNwf#3L`7 zDFVj{DA^Q&`n6HK#mSikm3nz8nYoGSsYM{uiogLX4YC=uz5+Zm1D^4?#pVkhQ3a1@ zK-yBE&N8?I32sXjfogYfg#va9QfPu*4|4b|4jag@tQ{y-7t1j)FmNymfaY2lIaoM2 WxVX8@grtSkgakP_g;azDIRpX4*)HM$ literal 0 HcmV?d00001 diff --git a/MyLoss/__pycache__/dice_loss.cpython-39.pyc b/MyLoss/__pycache__/dice_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23837a13153c252543d3e48776b1fbdea5eff0a2 GIT binary patch literal 15245 zcmYe~<>g{vU|={fWp<*08w0~*5C<8vFfcGUFfcF_`!F&vq%cG=q%fv1<uFDurZA_l z<S^wjM=|HJM6u+uMzJ!2<XCeUbJ?QUz-+c0_FRrA4ltWNhcTBkiWAIc%HfLQPT@%5 zY+;DvNzqK<PT^@`jN(nvO5si6YhjGyb7x55PZ4NgND)Y7%i?cljuJ=_Oc82fh!S*X zND)pEX<<kafyxNEGo*;7h_x`Jh(Se!-5FBEQzTj#QY4@vBJK<+k||Ox3@K7j5m9%B z6zLS17KRiVsEC+5LyByQTnj^r98^Tyogqa&MWKZuMFA=z;m(kvn4;9ekfH<?k#uKB zQBF~5VMtMdib%OLq^PE-wJ@ZpK}Dq98B)|!G+G!^G@v3fDNMl(n%Y%8T<NJL@g)WE zX$A3Vc{&Qg`DrCCnaQa>`NhRL3W<3s3NFs^iFqmU&aV0)`32tbIr+uK3W-Ij3TZ|8 zxe6I2B?ZM+`ugdaB^jkjddc~@`o5mdx-Q;sQTlm#p?;|)`bjzYN&2~o#U-gl`gwV! zd8sA(B}Ivud6{|X`XIyN(@OJ_OEUBGin(5bV%_g0BLf42Ci5*m=c4@L;?%s7qWpqN zKbKn^Ap1e?x+M{wnwg$al9~cm;R;a!lF?+m#gvz)$#{z+EU_pvF)1filkpZ?YGz(> zX>KyeSdaxE3`$b03=9m;pk$TCzyL}%Of3v03?+;;3=5bRGBh(bGrBM|GuDFnDa;EQ zYZz0QZ5V18vY0{1E`_z1DTPgvA&aGkv5%pasf2X_TMa`ElLW&;##-hY<{IWKc9=Sd z8fk`L22FN96lX9oFfcGPFfaszoFTx#z>v;R!w}0+%b3D2k*Sa+m_d`V2&85ugC^50 zCOv~&jPXUB3=9ll;+L6zMt*K;a7liVp?+mfVxE3^eo=ODL1J>Men?SbUa_yIkG^ju zC|vYYGLuumQKwf>d5g;?C$qRDIX}0+j*o$X;WH>es+5ZJ(@Js^E8;U!a|%+6^lWnS zlM{1_?eq|e!A=QT$#{!BJ2kZ+B{R1O6r`FQw^)l3^U_mqG3TV_-C`-utV+GboSRs2 zi>)ZNAT_b%7Hdg<QF6vD=G2OUTg=6!xtc7u7%Ofu7ZsG;Vy%b=sbo%#uK@cALhvv! zFx=upbr49b7+g*;Nip&<N-**<7RfR&Ft{ay0ubZ{5JrjDFa`#O5{3ngHH-@x7ciAD zF9gLVQ;|^(qYFc<SS@o6Ll#R3a}9GdBd9>>WvXQ;l*m&kVXXo2nR}UB7-EG$sun<1 zEo7`^D732q#UqPfND)6MnKIsDFG?*-Eh<jcWV^+bS8$81G%vHTH1!r+VsS}jLFz5$ z%)F8!36O`_it~#~Qd2ZpKw<^ASQ0Bzi;F;!ev2s)!}naprMdAXsd>fuMIxZY0g5OF zCILnkMh-?XCILnP#v*kF28Lu%A_jRJgt-|Q7&xFQs`v&Y149i{3S%ilksK&lr!bW; z)i5kz1}AN%X2x13koRj@YZ$XwYS}=^cOgUn97cu`rW)2}#uBC)wr0j`=Ax7u#s%y( zY$?o=3@I$>Of_sNEGdkwOp*+X7=u7bow0-?g*An(g`tEKWFE-0W+qU^X=W^Ht6{8V z&$FswT)?%Ep_Z|Pbpdw`>q5p__8NvPo-E!J_H3r2O*QNb_`rOAunhYGffSY+#uS!? z%!~}-3@MB(3@i-I%!~|q0)`C5R}2^!!7!46k)ej6hFyfAmZOFtg{6jFf<c6#mNSK= zhCu=%&eg$?!dS{Ufw8Eqgf&aBgrS`wjS*bpbCj@V3Dt0<u%vK-A{ms`A?#X^3``|= z3TF$5FI2+`k%jmQ!mi~+GLr|ck_#e>qLQnIWddW-k{XT_CSj0kIl(k%4HuZ^DiN;X zs9^%tI=#%b93{dvoNyjziEs@UoX1teQNvlomBt**pvj-Omyv;iOF=<FAuqKgKEJf2 zptM9Gx3suKAt_ZsBS}XgSx2ElN1;+fN1;knQxmK*9ibvoAtx~@HAf*gu|PqiI3uwD zqR9|hlY)Pc0zwMOFopagh5Wo!g^c_Xh19&{{FKbRbWA(JR%WIt<R%tpD`XZc6cpu` zWu~O2C}cvNk&#%Iiev$b-4F*`6(klV<|6sf3T!A;IK2dg3(8F>0w9a6AlAbI#!w+G zzeu4hF()%cp&+v&HK!O{!x?~uGxLxokZme1EKMv*wNl7TQwS+4O;sq#NX=77E6`CW zDFEwF%Tp-N%*g=-Jw$1WLQ-m4eo?AIacORDBB*YI7*~{9Qd*P;ai!l&P{9r^LBZuJ zO1%Io(m+KSxL%Ne)e8bKOtqkV%UHvd%>=0*Amt?|0|UcK<|0t(c#ARf7GoAdxCm5k z7V$DLFqnes4Nx&srGQ-JV=L%yu_TwKB%-(o<XBKyV+<;clo%L5^^{aCV+~^mLkdF* zLn~7XV;Ykrg9t+y10zErQwl>c!%8MUi1)zGV}dx3xgtJuCF3m?Q0zcl33A{qPMh@9 zywr+<B0CqjGv#qO6Q)v=@fLGRCThJ-hBJ$FK{kTiSfmGH8Gr~d7v!=cBL)TrSC9}W z=qL>mBJ6U5+2sbQU_@@QmXsFcq!tN+B!obOFsL?R1?SBoF%X+QFF!uLC^6+0OIc=W z`7P$eoSa+CRhb31SaS07(r>X=r55EE-(pKiEz3+!y~P5mmx^>i8d)>*Qc^2!aTF&f zmXxFx#ouBrE=f$z232B7nRzMTDn^t078jzBxy6)TQltuUg(1iQHb{cM#af(E5MKeI zDsORA#^>jyX5^RLVk#-P#gqoZd63u!hb|~!Z}A{UuRJJOS%E4s9#H+p#>mFR#VEk2 zz$C-O#>BzM#mK|R$Ed@^!N|kN!^Fj?1%ga$j2s}y#K*|T$j4OV4XO`8Mj&HQE&}n3 z)fgBUQW>HcQy8L{QrN+D0dopV3uhEFr2dLxX=h+zh+++9(B!zq=?-fH`#t}d@wj8_ zk0*90?F>)`h2dg3a61Fka%^S-wH+B5Y8Y!6;+bohY8c{KQW%37G?@}jU<Ed)&~{JF zOD#&w$*f9EQ2^Ow4K5!*GS2x0mEfke6}Gmuk7t;tqiaa8esO70T4HjlE~o(wX<-{0 zSR|!bT38w-rX(ktn46`VC8wmB7?~IvnIt8f8JSs_8<?9|TBL&8+^NNS1(nJ^MwW1U z3la-bix74u7FA@H>E#!t>ldV?=^L6GnClrB8yTDH6{MuW^^~M3B<H83B6K130{k-5 zLcI0;GSf;b;hlh#)H3}%kO-&`kds)MS_F1Xd~s@eZfaf$xS)i%%N*e_KTXzKti`1T zsYOLTpcKauAD@|*SrQ+Aizl(5Ag3}uFF!9QGw&8#ac+KoNybX%TP($?IcY@!pfnl? zDi|4WamL3d=Oh*v$3t7sp!8Sd&%nSC3QBftp!C7c$Hc~1B?OOmJ(x8p(GF{2fx8V} zpcWQm7DFv_4O13FElUX_DDuH=C~%{)gt>+V+=g<iVOqda!;-=%$&kX7&Q!ya!UQTY z;cY0U5>`-Gp@pG@t%fm+0c0Ay4V6*DRLh!YRl~G^eIclu!MK2<2Go!Rwc9wexKdcN znTn>=fLd0JS)5rs3mI!!YgiZXrZCkor7%I;Pz)$-sBNfis4Tt|hEm3&D>V!Y_&^Qi zTE-Nn68;pn8rEi}1zaVJSpo|gQ`p5BQaFSeq#0`2YS>EnQ#flFYgn6^MHp(?!F;wF z#%8!KaZpQ)BZaGkKZU!7y&2Tb6a<NZ`5Z95G(#<C4d()(g$%V^HC(|Anmm4x+5%LB zqm;9td<M$KBvgh)nxNR{gcINh2UUWZAR$mmUZskN>a_f#^2DMPY&9Xcn5^PVO0KA^ z(zQ!UPT#%v`px?_w{43w7#J9;1i%tH3SbH(|MK7e|NpC4lah7qk~D>jz;!ODh%Pb% zSqv_{%|R>+5Mc=-K!skBHHc*cB5XkfsBA8>1F`Hu1UM;!8uqu?({l0?ON@<*KneR6 zOKxgno+ft@C}4`5Kq^7$vItyTfvXXBP=;YFzr~Z8R|4vyfV<hZSWELV^YfzEQ!>E~ zvRlmQDanw+8604su)4(#4U<Gr1_61Mfe}ZU%*@5e#Vp4tz$n3}!pOl?C5REOnoNF0 zzMzcAc8eoEJ})shH9r0pSA2YKeoAQ$h|LopUs#%$1Cc2L<;$X2kjr_Y*)};fCkK{w z!A=YT`4!|&aFGDw6@wB82O}R3SOg{cfMh^f927RipwO;iSiq3N2<i-`FqJTtFr_d< z`$bI6jN%NSe$fJ!8m5Je3s@I|@;37VHgJCgG+x45HPbBzR8{5`D<mqUXO^YrDL^`D z3dN<#844vCi6x-sp+a&-VqRWqjzVHZW-+Mso0eHrT%rdqIYTm_&0febN(rc^o|a#f zo0_7KSgeqipOcecUJPj?Dkvyu_~|G(>nOPBD0t{7glj73+9_x_>nQjsXeqcTXeoFo zXeoq4TFc;8q$U%Xg5)GnyuJMY|NnnY_99T5p~wUjZs2wVs2O*Qy&$zHx3nbn7FTk9 zUP)$pX?|&O5h!_TvfN^Wc<UCU^DQorC*t#qQc{a<ae<v%kYAjdf|+C>A=3m(GN9DT zz{<rW#>B@|ltsOYnH!$E!i*rP3Z;Ao6%{Z{T(xXPjcU0(KPRtJzn~I4&Y^3Z0v_Z7 zk9(P=nkE{VB$=5QTclYen;9D#Sy(0;8kiYdBpRBhq$V1gq?i~Pm>TMXv>2y=>qBVe zY;FopE}%xiV6L2t{6L8+8$^I}F1WA)6;DM$ATdzpgrpR3eg!F+07@XB9LWHwbVZQj z6nnJ`(h5pZ;4B<9Osm}z{u;(+##*)#{u<V1&>#u>0)d5~3YL9=U<%tp##)XNp%nHS zjv5XThAiP4hAfd3j$TG-h7?ZFFh(tB4I5}U1fsTvQv_VCE)WD)uWU8!;tU8Guqby5 zcMX>aLk%~wS@f!gmw=K6xEfYRj(ljPgR>$oN@ieScn)f=F))0v2lfAQGxOXu8E^3x zr52WE7Nr)0yFuv?6$}S7MX6Z>gNm_Rtl%0LJfK&^4yvSBlXDVt3&0h!CJ%wCm^Cx6 zptR%`TS;P3dTL1&7le*4F3r8g3u&O1g9o;YZ*jtF>|5L^sd@RinR$sN`9+YN1g--? zc>-Kb&H-g7P|p=_70m?hDlu{~$}tHsaxjT8RS9DWd016l6bFhRusvV`oQuIWf--tP z$VO1L3a)TLBA|GI20DZS4HbY`#h^L}-tFQ9ce_Btg`jR1R|<OyM++lT$BX+GTW~OF zyxPw#q$meeYk-RhkU<O#3?MeVa0j)gK*O3f4Ddb|E4a_a4C-@%dtR*IUKm>qOASLj zdkUyW#sQx6V6OTMFIj_A^NKS|GRrbcDs_VkQj;^&GD|8IK)ous5~RY{N?*Ucyj-s+ zKRvamBr#VnIZ@xKD8D#4Bi`53+1b%E-pIhvl(H^Tnn9X{nYppCv8j1lO0uO<Qc|*^ zrD3u~T57UEN^+WsWn!YKiKQiR{U~EgBe*xxXUznGo%Vj3oRBiR3>3$p0=Wn@=5mWQ zB|knX{T5eJVo7pFJZQoI(TBRlShSK2lDNU;dj&{qIfww4^dS93;4}hGE#QO#GWZ}U zp@6bEq@-s9wILbV7_0apv5Y9}L5T?(%M74qI=rx72P^DZK)s?`R+K&#OEY5$GpMl4 zW-fyCu~^}KELIZxSZrv0ES80gwV>V@qL0P4fV&3P$6~8tBdw2B^dN;Xow0<!L?DF? z(l2a=v^LrEgwRF7<Ku#$UfBf3B90n{1%j}CSq+;6Xat$9hEakcg-w{Dma{~thO>rE zgaOpgDiH!z0pR|XusEo%#s!w;fXi}#RMjwKL1ei>RS2jGNCEZGL?jq$xJ4Mm8EScI z7_vksFcvSWVaVc}z*y8*!vhP;TAn<S8Xo9)AGi+-8uP2=so_~5wh*L~2Ry|EX?@bP zA9o9s_CP7CN(q{l2-Y2-l1NjWN)-Yq1KeT(O`}1^<-oZDG$8}72#P8}&I9L<2+$-E zFC^>0yJ)vKK_yH)cnlP^w+0?si{ea5hYp}c@j?dJQsPsKit>x11d1VrRB~oXWqd(m zQAw0QF`@_qi`-%=F2+_*Tm)rCP`?vgPcUN}USngFViIEFVdP>44YbKI$}vfR>kU38 z5m48T4>hPYSz*Prn<k?l?kWRR&Kv+)4$9RWp!u8sd;&!vQL0rNJjm15FjpmmDhybu z0%F4}56~baX_W_P2(QW(UYH`)382|;w9zs%12a7X6LU+{(K0_x4oJCM2TB~E;t^c# z7J-uqC_N#H+oCFv5-bHO10w@NF{pTDU<Hj=K?>9=QDom^FH}KO1E3%Px1f)JhVMZG zS1dIQS&T($;G&7CggJ$&gatHY1udFb^Mp#!L~B`V7_wMFMN^Sa4Z{LfNWlaewyI&O zVXXnpsIoV+#xU2i)w0)eq%eVtCWb5yaRw-xU7VqYLmV`V&H*l>K$C`?#oNH{SX9HA z!VDQ0tL4lSso{hcG+;4sK?52v0!<}y)^KuzCKy<tLq?#i4es}@<c8#LaBx<G0ukKX z1&_OcvnXhc7!vH@iVGa};4I1o$)b{=CLLpCQ4L5hIBSB(x;;Q#PY~e+Qow{UU;=K5 z@q%JU29)9$*_a@CiHnhoNrFj*2^3S%TqJ@LThL;Q*t`TvJO9!05^oVmlxlg2-7`Ow zxHcG>*(d>5Hp<D&%hoF{D9X%BPb~tkB}vXmEC8(=F*MLKG%zsNw=}oVHP$sYwA3{< zurM~&HH<ehf|~|sz;jwND8|UkX=<Rj1m`q1q?{%I4I3gd*&aG&vKqD;a0Ua_{a7>C zI&kKiMQY|Mss{yL1E~5!${R(EAW{6egE=$56k7!ZngJ=sl0BGZkh2FrEQm?W8KAgU z1w|TCEyGs?5~W(s;0^&T^vDLUs3Ba*fQEBPt7K{z;z4C5Lp)~+Xp*0W=nBUiRN<JK z8yTYXJ)Rqb!DajOG%$E(@2ANHse)QTv4}Tg++t14Dac5?#gdd-l8CcD@&g4pXbB5) zeFQ4&szi{&iiq5Dg8sRMA&V0<U;?g;xIlv{U>2fU;x66=&N<77$~im>!SflQY6+Zk zG<iv@iP}Li(g7l{)=9+VJN6Pt^2C<+VnGpwE$?X{=RINMNFy=-fs&Lj$QGph$6o{z z#mIlK-d?d5@|-MV6f?@aEDLmAmK8iN%Lbm86?6kFN`tx*yq>LU3%J<F-ypLvG|)3J zw=gk+=a0nFl8pQ!+|!qNiAlwYNkxeniMje|AUk!TfvXD|sD;e98k?A>7+V?|nkO5X zn3@=v7#O6arCJ!I7#f(R7^NkF$Rtb4R72!JT{Fm_E_4j*g+0dTH7IOA?Nso<30m<1 znoedeVM$}iWT;_W!q^8I(qILxw9sVq(`19>>@H9e0_E&mtm%olxryKyDCz_i<ILbD zP81*1C-K?kiACwfMLi%LSlXJ^pvVNZ85tP)K%I6ejIc-KYmkdU?H&*Y4gG*HxTy(R zDFPaLSpb?6XIjW4&QQw?>cfMZip&d{YFTPnAS*>c-9=3nuuCC}8kh_6%fXJ*WCc4M zGQ(BGz`zj21M>o?8x_S1j-Ys`7`D{kfbfa{BOhaxB&JtjshzQ?8#GnLUK9@smIM$1 znghGV9v`2QpBx{Ln4T(10!e}gJir7v2th4|B2XNZ!K~z9WMScB1&d*11W+)6f(#tl zB#e@AI2Kpt64yqfXq+q*oVT#dKxC$-rdAZ><QJvtftcWh0>veXxdr-QQZK8xz&^1k zuQWF)waC`O)ZEw{X+|G9Jo^JN5;YMNlAuHjX_|veKyVWt+-3*2)WJm~u5M8~D6m1J z0jSkbl^`sf5GfUuzd%6_!l2oFPyz?1)Ccscff%yDD*-{Xjt~~KBLr&4af&n4u+?xv zYNHh98cxVsMsUr<30B8d!;r;Qd<k4L?W^H}cbK^HL~6J&I!s(OTno6tQvh6unu&+# znh6}};NS;G3Am#Koo)tol!&R6*dQa?Q7nnYmATmR=2TGJfmYBkpmv%VB^XiaCINU{ z!m=un!w;au)dsQzsbb<Q0*PYeM_A%OT`a@~9<^plVFoW2Vo70bVL%$QX1~P;Um*{! zk-&2U$uOfp=?ugMVNfYqjHOw~3>vtGmg`Kga=nBtjUk1xgguiHUe<Gf%6ibMI#9NN zjp9y*<(*rciRtM@sp*L&sYM{0ZZRdNf_sMGWzt2UY1x%5;F9!~0BC77c!(ET2}E%w zr$U)UQ$YrT3o~%B1}@P+c@s1gTMP<ANPQr{SS60?aYQBqSr76!2!k>h$n)S#1{!_> zO~2PL)-Ww(0<W!Rgt9?n5lktJ*-S+`H7qsE;taK{c^sfIiDt$o#v0aOhLy~I;GnHy z&r8fuD9+DY$pZFN(KL{QL9V_f?3Z8SnOl&P3R+H-n&Ju?GJ}m<g9qP0wZ|>?lw@#3 z-C~EXi3crIFD}g~!InI?fKm=93>ZL5&lo}dc@D-ZF*KjUk|ZP8g}8^pK#{c;WFja; zI2ifZia;V5=@1rvsPzJ46f-;}azwGDKvE*2dBk>0ARsj_F{h*wKC@E=Sy+bFI?69d z%_~X;Efv%Qt;9;pFDlKo&rGrPFw6<buqe&*_5+vu$)GA36onuRVuKtFa#}HHeF|#Z z1XRbzgWD!4OeL&HX_gI?W;L1ppoKn*19+56lO57pSpbSUP~y79m<?XJ0ZFSdpymn_ zIITwUq=Tj+V2Kk;x;zL9Pf#j@q{}Kv^gzI#E<trK2!qolXmKD)x?~1<D2p+Lv4$D6 zupC_Uf}1Kuy%4v7oLICF#KvfyaHK;MSQK{wBtGIXSJ;3DjE*6k&ci6iR3(AobZEjV z0!_IgB~NfLfKp`<C<H-k!HYp5z`@AE%)<&6!$?!0$OA<rX=R%LG@+N|7i5Do4LE_* zu5>G!0t!;-1l1x~>VT9<ARj<7J7~NMl2SmWTG4cnI&hGnmRRRO!HZI2NuqlPd)fen zCCD~#+MuM&Dw+v$+AI(Onx;jHL2%@ORykb-ITFM}DPSZp912YdnuLlNP@tUynF;a$ zxQGFX;7$%C_rbVbD++w_)4{2sst;UJK+6|n7lmM?rJM@onI#ztt`!COd8v6NnTa_H zKKbd1MVTcTxtYldpyn%h-MvR<dPYvALTFx6VoqXSa%zf#e^ORza*0B4W>so2*nHS7 zGqiDb3j<3%0|NtdGkCYIs0}n}4DvZBse>?x4Z`4Z3$yMmVFvf6V6`qw2`jkPWkV{# z*g<`jcyQf|k~$ZIVii<&6@iM8Tg>Gc)h(p=0;zM0W`H^_T;+&*7FWOPDJb$_saAkd zgt1BlDa3InSWwyrCs@#6HKH)%2CZQOZ-N3>xhyG+B}@yLOPCh0lrSw|tzk}KTFA7J zksCB5&$f_Bgdv5wh6Owi$O@`-HCg<i6UQ&@(|g};UGV+b20N(8NqbFpNDaIM6cpfP zj(wovaps)-^r8Zg7|4&B(1~Ms1za>2q=*9)6yW|FH@GVSNl@5w$ZJr5KynBlBM-X* zBinx#<|<+2KtW^<Py)bRT7u&20mwK|4hQ$l*or`6nv8Cm%zl1unzBWp)k%;JD8HXe zJa~g4bPFQ5RaFE^OyC$V0+|uT2VV1~msnbo4_bi_UK3ITioIK`d8N4pl|}NPaF7KN z;ORlo1jH>iU+~gLNR0<tR#F67L;;=>gS6N{!<LYqJE$)SZmfb^3gDu>2$Z40&PEC? zQ0#;Edfei$fh?@FV+73;fU0U177j)caPx(Sk%NhYgNvKXR7j75SIAJvUPw(yLr7Cd NKuBFkKuAzX2mli7_Sygd literal 0 HcmV?d00001 diff --git a/MyLoss/__pycache__/focal_loss.cpython-39.pyc b/MyLoss/__pycache__/focal_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2734f4edcae0885494fb2002e9678f127ccc8d GIT binary patch literal 2868 zcmYe~<>g{vU|={fWp?5hUIvE8APzESVPIfzU|?V<c41&(NMT4}%wdRP1k+4WOkkQh zmnDjY5hBN$%fiIK$dJnx#m>m!&XB^C!ra1;!ko&I#nH?h#hJnw%%I8g5@fz#GKfUR zY*0?I76SuAD%h4NrWD2~<`nLBhBU?$_7sj5&M1}?&J?Z|hA7r{1{Q`WwqOQLo?D!5 z`N@enKKaGPiCN4H3|tBd3JS^j1(gbEMftf_3K=CO1;tkS`stY^8Kp^j$@#hZ9>t{< zrTQQp@ge#7IZpW%x&f6T`9;YY`bjzYN&2~o#U-gl`fzg*EWLtCu=ybwnZ*j3#R`cE znYjfysky0nC5a`O`FRTYX$oL%3Lw)J$}>wc6pC~6^Gh-mauSnLa}<(`@{5ZVQu9iR zKrSpUEhxw@DoIUID9^}D&H$NMP?TSgU!0nvkeLUzNgc_!wEQ9km(;xCR0aQ}tkmQZ z1((#4)MSvodI}mae<l`HWR~gW7p3bbCKc-&nj2W?85kKD7-_15%>#uF*vNR03vCs2 z4Yd+;3NjM4Gz@hMN;I`{^3yd6N;JW0tqKy05_1*uN^|3ra}tY-t&l~*x~vp5N>cNR z^NTbUj9nD+ixf;;6iPBu6^fG+a}tXb(h`$P@{1He?k)j^SaMNjNorAMKE!KK<I@v! za}%u;G}3bN6H9bb@=KF)QZ*I8q6&5j1`0)~DW%D&#b9%bQgaeZGRsmGazMcdGPgW6 zCr20Ty3Dl9)D(r(ibPP16f0;H*cs@VYAO_zmXu`Xr7Ps-7o~#T4+;TfoR*(lTC9+t zr;w3Yl%kNES&V5KSRd56km$s4FGN_OEHS4vRiQj1H4inakxfWS%t_2kPL0pZOG&M; zQqaiEE74Sd3MzmtR{%@tC=_Spm*%7>B&8}87o;X<re!8Wj0T52D7ApY4B71B%&OG* z#In?)#Pn1v1&yTq{2U#H`~pxKNzBnyaH>>DNli;E%_-3VhbJf)Q;QW6i&7P!no<<< z%TkLJQWKLiAc33;N!|*Xd0-t$i6zMydSEO4UNSK-FcdK|Ffcs-nDMw{>yIaPA<68} zj0~bcY-R=q1`sZeVPIe=VJKm&VQ6Mrz_gHok)ejMh9RE0hN*@jo~4Glh9RD{hNXrf zo~?$ph9RE4hOLGno}-2#i#>&*l(8rR%wo)DC=LL#n6g=lB1$+@n41}!7*iNPMT;NA zEg+j27(Uon$)iLHA}}-a6rv4u3}ZF9Zm|}Z7Niyxfjn}HBR)PeFS8^*{uWPSK|xMs zd|rNDPG;UMR&a3NVg)D3TfESif<)gfHi#X!c#uQ>7FQT39k>=1<ritP-(o3F%}Fcb z00kB&hyazjMcg134~XDpU|?9uc#AV0lFQ=bS2FxE*U!k$O${!|FEZ4x%t_4CFV8Q^ zE-pw+PSp=7O3W+v_4LvAtpsIZ{WMVV7Y~X_y@JXjP>_p(N^B8OFfj5lvLPT77o!Mc zl@K`lgS6_wJe&-QAdq?x2008=8asnb5ny6qsA0%r07rpH4Z{M48pef;wTv~4S&S)+ zk_@%XHH@IhOaVnlMh)Wv5T7N5c_CvB;{uix#)XV2EDM=I>KCx4FsCr3FfL?Xz_yS9 zEW(_^0#?PckZ~bXEn^A$0*)Gx3Z`178m0xzDIonIHbWL?ElUkc3PUMF(VP<Y1za`E zDIgszc{L^6H7q4OCA?XDX-v&bMR!X0Ygn2YYnZc{iuZUiG=a<%s9{+Mau-xj4NDEn z0>On0(hRjMc{X7ELd{GQ7>fjuO|D@{VXR@u;*?+zVW?#Ru|P>POW1}XkE28+g%y<O zYFM&FYZ$XcvzaF_78%vBED!_n7Bbed)G#d&U&ydPVj)8<YYF=Tff}X-kT79h$h1Im zAww-&4ch{#6vi4>P<ViNNG8?Pu)$4Yt6|8Jp1@eNsDyoi3@AfnNrU_@!640$0+y3t zSjY_0U&AQEus{$L-XI;~47Kbv><eTUGSq@HkX$Wi4NDEHBttC=IE-09ZcJgRVG&^{ zVP7Dh!dSyD!BE4wkV%B0hC_s*maB##OJM?25lao%0>y<4wOn}|HC&)DtL3WU3TDt` z_j?J-^hKal1I}_U|NsC0ze>z6zeE9Ar73{(j6z9eL23~z0|SGm5;*<eVlB={EJ(e@ zoRXP)i=`|xwfq(fDC68>FGwxQEiFmC#g&|&SCW}tnqOLci?sw)C*NW(E-Xz=tx7Ed zWu99s`FW|ux41HkGxLf|K;`HymYmGul3Pr91-IDqQWA@b5-V>p=cMM{;&uboC?Vkb z=@xTwY3?mnaCr&Ndbc=3DhpD<IrJ7=N@`hVa_TLnlKflD$pxjiSaS07(r>X=r55GK z-{L4vPAn-&Es8G!<<?uQ$vKI+1-F=U^3#hz8Tl4-L4J7=C=cIa$xThn)8sDVXJBBs z#hR0!o>_8>tt7E1J+&l?6JCo#yuqBAQW3?SpO+e+kzW#@omzQ|sh}hZq^Y0;VlE3P zOBV@&N(e9^3^J8HEx)Kdu_#3ulsQ8{*_DwC1eti4q!@)5g&27ld6?OlI2idD1sDYw zg&6rjB8+^De9Rn7e2hFyJWOIB9wQGE2O|rk*nbu#5k?LskPa?J872-!E+!7fDnUdM z0kuz)&Cd@J-$kHyVTdL>sL0Gq%uS7tzr__FpPQdjnge3<#K#wwCgwn8*yH0<@{{A^ z!9`P%Imj2h(4sCmH76%N9?6daAm4&K3NDdBWmz$(DB@sb;p5{4i-F4Qywco)$|6OO zS{aaW9N=`KmzQ^oDKGDqFpQ^{R+<NH9VO=6;zD9Gy4_;)1-lh&B9irpaOSXq_{I*D R(2GGuHwPmJBM+kxGXTzdC@TN} literal 0 HcmV?d00001 diff --git a/MyLoss/__pycache__/hausdorff.cpython-39.pyc b/MyLoss/__pycache__/hausdorff.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4472c5317ce1e97640f7705acdd550d6ff75b161 GIT binary patch literal 4122 zcmYe~<>g{vU|={fWp<*12m`}o5C<8vFfcGUFfcF_r!X)uq%fo~<}gGtf@!8GW-!eX z#gf91!j!|3%NoVX2vWnG!<Ne)#Q|os<Z$M4MR74QxHF`%rm(dzq_Cy3W^p$&NAb8b zq_C%Ov@oP_Kt*^{7=sx!IbVY8_tRv&#gvz)$#_dFC9}9BF)ukazN9EIuQ)BgC^tSe zr9_kQ7DsY^URi!lS!yy!2Qp@ba*87u7#LE)E{<YKVT@u<;b~_`V@zRAVQJxvVo7CP zz_ySfl|73Cgi~1OFr~1iu(z-@Ge&Wy@TRh6alvp3=N#q~t`zPT)+p|F1{Q`Wo?r$| zzFYhriKWFU`9*1IE+Ibo#l=;IPMLX$MU@J0afO`x;$nrQ#NyNxh5S4Pm|qm&et8ac zszWjplm()g85kHqm>U%I5)2FsB@88uH4M#63z!x%Ff!CI)-c2~r!WLFtYq}lWVyv! zTw0J?R0Q(&Esps3%)HE!`1o6_i8%!siJHu}Sc+3~(uzP4eTyx-Jh3RfcqQX4&iMG` zoW$bd`1q9!zs&SA@^e#zOY(~h^(%7{^YqK}i?WLg5|dN)LqMVK>*=HKTM6=_eg@3T zdIgn5oD2*MVjwqy(g_12A7hmS-0P4a(Stchlj#<7W^Vc|wxZOM(xN;R?{Y9OFn~;O z1{oCw_U;148pef;wM->UCCoL9&5X^AE)20&wahh)HO!I>DU8`HwJbF(3s@F16p7Wa z@H5o1max{aG&9z+m9W;ZHZ#_;*Ra*Fi!;<P)iB#I6pGa_1v6+e`9XvGfxRZ%EvCGJ zTU=GCMft_?Ihon1w^)l3^U_mqG3TV_-C|D6tGvYwN@pBJphzeZWME*p#hR9xnv-&i zH7T(qIU|ZaEj>Oru{b-5JqgB0OOFRdF<1n|DB@>eV1N(;AOrY7g5bmqHc1NPUm;M0 zGjTBqF!C^QFtRamFmf>QF;&T-2PCRix0TGdSPF_#Q*N=9Bo?KomK1@^Nd|=?$UG1R zML0VH0|PiznHU)uY8V$VEChueQ?UY+&6Leltl-6v!raW*#0U|qVFIbjW-7MuVn|^D ziI*@gV5(tQz+A%!3hxE13mFzNE@TAdomv*S9215dYYk%!lLSKwt1v??TMbJMLl(O* zLk%m07H5FOkO)I9M-9gUj)e@hoHfivziK#P=G1V&SSf6h3^i;hJa!ZwI30l81y;$5 zESJWb#uUs@!x_ww!XC_^$?<DCmjV=o7N-^~<maU-BtbH5az<iaUTTh&f}w)0LRz|x zf`NjrLQ*<hJt(xT6f}}_6by9~Ds&Vobrh;J74nOag*4%+A+bWd`j?=ndkM<Iza;$g zaw-*!Toe-XQWT6`6pBj=3i69eQd52z1r(*GWG0tn=I4PWA(klQrKYARl;kTUr7Gm7 zDHJE>rYfXl=BDNqXXfX<1eGhAOht^Kln5gHG&yfEr)1{dVlBxpO3t{&l~$CW8=qI2 zTTodf14^XK$pxjiSiu6fSkrRy6H6dv2TN{hVxA@!BtOW3jFSf`XH7{>Dou}K2L(@j zN=Xza#8L4nCAYZ1VyQ(%`9-%lVCgrCBc&uBocwRGfJ#S5_5f!Ta2{b#%P%TVEK1P^ zWezP+Lgr9l;``6S#PpAaO@>j1iHA`HoJm-ixS04Dq3}1;KNg-Uar8_BHA|DRh>d}P zL6gx<lNVG%<|XE)#>d~{ijU9DPbtj-v3cU-3riDopfc?7@hSPq@$ujiu1Fo^Pxie0 z`1GR0lp+xbFD0?4C=nWDY$2(6#rZ`=h9D_kXvvzKnv)YBkK{dGP?&-I2rj9GLH+|} zS`J1QCO#H6MlnH#|4fV=j4c1a^0;d?P+g2ttMP$rHI@`sP^|{8!x*AiQ`r`<FJwsN z$l`?I6plH}DV!->;2MrAg+G-oiyMZ)ay%)#Ev!*I;2MrMm_btjsfKe60u}#NCIo7@ z-29?~jQpJZ^vvYM90gF}nwg&m4*Zv(@+cWptAXMagh44D#0M3y#cJSUAG4OLVX9$> zXQ^QTWp7ZEBAB5FRH$jP6@hAvA}dgsf~p}%8s&i6SbU4WpeVH<u_!e@JGCe;HK$mU z1(I+<t|+nr>EZzq;5dONK5%q{DxhMJ;h+kQjgbwxQiFs7yi!XBWhjtWKp5l|P$~nL zx1g3s3Bv-$5~c;rB}`c?DU1u5(m^#EBS<uzA%$rXV>&|$GnfRinZdG5HH;}NA`CSQ z@vI<qHB6u~G@W@NV+t!+lnqqrYqI%iGQ9*@S7ZlD)y#;@4Q6RV1YUw12Np=5H4_9* z+NV!a@tmaMdB|Rq9g>V$!HM`5Q*v39XnJZ%a7j^Va!F}XW?s5$PHJvyUP%<UuYXX0 zN4#^8e{k?E4oF}cxxiS)E}G2XaJ|KvT$Ep2e2X<HzceoeDba(|yDb9)LkuXrLxNR| zQI3g=u}U7ETcN><7G*`U3=9k?RT?PDKt_OrSPfjIfvXrsVFplhAd4Y|DVwQCu7)9t zv6eMYrG_PoA%!`csmP#)A&aS&HBYF7IfVrzI-4PdbuPHlD`8o{TEnuCv6d~Jp_aXb zt%M!ajA#ZmH#XIB)UYq$C}CN^S;JDpA;|!$UO-hJBSQ&W4MP?eSXEaoX9+h%zJ?Xl zq)p-MWvS(=;Y#6>V5sFzVXNV;;Z9-8X0TzX;a<SAkRgRFo4M#5NOS@3LIw$_s5C<@ zcMaGa8-^OT8V*(lNroEk8qO4MafTX(EIwfd5e9PxW{@e&MNM#baU<MR*i;x(!#shp zNDkEW0atr$!3^L8!{b+^3QFxtAVQIWfuV}a1y(^QR0%0$7AxfC7a=MdJ#exCx3EEt z?~-&-`UYW5rduqz`6;QlSdz;UbBaJ^Sdl#_4f5Y&$xF%1Eds?%krSvG<O)m7DNS_+ z75?C~RwN3t9MoDY5(mk%B^DH<=B3<XNzN~*gycw2fpCktprqgyb8coHI7e&pfRk4d zyuK;Y1sUTCHUW|vZ?UBo<(H(UK+KQ^$uMSv>YJQIP(>NV4lNVV66G!ag47~NccLgC zR8xWq++sbDFIjk)xEMJYxfnT^gqXM(S(pTvC7Ae_MVL7lc|g?<GY6v_6APmRGaDlZ z6P9F*YOWhf5(X7sAPfqq;zbPPx5{dmi?V7MK<#8ur(gl|LWUY}_W;adUdSB7RLfG! z3TY$PfLlQ&tP9xi=~&1Jsxd$r3DknBWvc-tAJ$+7O|~KhP@sbnP7yc>DTBBw3=9l@ znjA%-lDjAvB%%o-Kvi~;4u}P+Hj6YsEKpsp$&NK913>cN^4I{x0vE+#0$gx_D%6Rf zf&&!)3`h+D7Dg6EvHu**Jj^VNGXFVPIGB(se`s+6tNfY3^-qyCD4jC8Y4Q~rfow7c z5hftQ6hwgQ5KuV<uI-D=Kw_YpzQ`QJ0(;K_!~(gf$P&Z?yAn)*D-p1JLDdmRZ850R z2CmrIIM_jLdX|5DdhqHUq)L-7iaEK=sK^my6gbSlwt>S8)MSg2C{E5SsMG`Hs>JkE zJ$S`hSp+WmB|+(&7e!4GIRCKuf}1qpHUy-41iKMIfE;>@!v@mwu>(~S#e57544{4o LBz^O63poG)I}FHg literal 0 HcmV?d00001 diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed7437016bbcadc48f9bb0a97401529dc574c9b2 GIT binary patch literal 2411 zcmYe~<>g{vU|?WmS4vvQ$-wX!#6iX^3=9ko3=9m#T8s<~DGX5zDU2yhIgC+^V45kH zIf|K)A%!`GC5I)KHHtNtEs8CdJ&HY-BZ?!JGm0~pD~by&&zi%X%M-<u%Nxa;%NNB5 zR>vR3pDPe0kSiD^m@5<|lq(!1oGTI~k}Db|nkyD12G+-xBc3Y}B>`r$=Sb#CMM>pK zM@fVE962(%vQe^NHfN4pu6&eyu0oVTu40s8u2Pgzu5y$zBgnm6IV!oTQEH3~DWWO7 zDSR!AQR-mVYeZ?J@TUm0FhpslYNZHfGfiMDQcKlrW{A>GRZrD!W{lEF)lJoDW@KbY zWn7?_B9tP$kSR((l`+d8MJSaq%P>WxmuZ1fits|lDC1O}RQ*)rW~L~U6sBMXO|dAp z;*7-ns^piT(DZxB$iTp$$##p)$;BtXxcC-NaDG~eOJ;H^n8)Yh9G{q%67S>^p97NR z^Kfwqamvg~EUE;{i$LW)Tp)rFRYESA#U&5{m`R#Ew>aHloLk&*W|+|}w%}lpSw)}- zy~XaC9}42#;tnZGEh^52SR?3`pPZNj6@l2n;aFUm3swnpqq8eSMj#+HFEOX25^l#W z0jMiW@(Z#dYPel13Via@!5mG-Tby9?L8fUk-QxF1EG<sSFG@>u2>}Tq3AhG<1T-0M z@%rSKB^FnK!YVhhLX-6tpL0=uadB#1Nl|`5rJu_!j*$EUZ;;|!65*+t=@})dDPR?@ z5EWn<DO@sPMz{D}GK<0f$gE0Dady>Yyu}8JfYjpWpjcyI_+SrZJg`q@h6sU4Rt5$J zHU<U;XHd>?VPs&aVaQ^rWlmwJWhr4SVX9$nW-3xCVTQ1p8ERQ3Fy_hBFsCqPGZk5s zuq<F*2od9`VNPMnW-9V3VOhWi63hEm!<@pL%~TWv7kN^{oWhdLRFqf3vVa||=28uF z3Trk~Q9}vK0uHdqfg0u%wrr-N1tlyCI7_&i8A@0da4%%2VXR?V$WqIax1feOg*}_8 zXcI)dmL;#RhB<{Jo2lp$k`6W`9eFj(DV*6%MW2x5d64A&YM4{Fz&S<)Nd*Uz3Y{9} z6mGBzof4KT))bx;-d@I9mOPOZz6p#)7AgD_7>jMZ7@8PMc$*k&Sb`Ze1^l8|(?Ch^ z7JF7=a&lr(N)$Jg5ucNvev2g~GdVSi14M%ORa{BQsoFW<Bof62p{uw-Vo*(0{2)$F zetJAifCnZ3G9!w&C^e-tIW;97Y-xxl<1Kcmh+7qpXI@!iPG*V%C_QNMMR6u27N>#} z@-6m^#Nxz~lA>Eod3m?k(^E^p97dlgen@77r}8L4h=3QwGoaKM#Rn0B<x!Xt2tOUH zgwrcOGp_{1yd~)5>>8d~lHrq|o>@{15{eRl7y!u`U=b0B2;3I1DYv-75_3vZU5kqH zi!@npu_P9y7vEwkH@wAEZWP6toS&Bl@eB*de=8Y^bQu^JewpZJ<maa9SLP(<>6hmh zWfvDDCa3C$6eZ>r`+EB5`+|!JeUOg$w8Z3+{Gv*|g34Rm$wjG&C6K@o24x5<P#$4q zV`gFGViaKFVB}%sVB|u>5>P&fmSE&!1YtfV0Y(<aB7Fu1hGdWtFwDfjzyQi|;M}Id zz`#(#uz+zPLoF!RF@Z|*66O@<US>vy8ioZd3mIw|YZ!wWG+F$rxOH_w(WVP>@ht)L zl0lQ{7JEu+T4HHV$t{+W%7Ro))>|BjMd<~JMa8MN_#BJUOLJ56N&-OKqFX$PDJk)Z zP_bLg#U(|zxWFnwg2kFlx7fhEq9RQO21sy%!?Z|`fq~%`S8if<YCKqhIVglUKwe|y zU=(2FU@Wo(so;o@&rQtCi;sT^^6g7dVWY`-i!}x8E~X+mkfHLRI>huASA2Y8X-P(Y zQGEO@){^|9<cwP!U`j79uLu-8w|JBCOY>3^iz>l!Tm&)<T<#ZvLa_)`a22tG3<Npj z7AGh#KooL<1VQbAB5n|i2So6K2tE+O4<ZCWgdm6z0uiA2zQqO3KM*5CK%U{ufR*h< zq99dbAVF?O4uvQa2Z`|exx|Ab1zMmNNq`hff(R)PAq^sA7#J9$1PUs_K8!EP&(BFp z%_%981<9#_2vBk<G6u0=0hyPWn;IVvj)*9>+)8llvlbL(=9S!HElSKw$-Bjpl30>> zi#0X3prrB^OHO`X`Ym>lOnz?sE!LdGq|}@u9gvBjTmgwjFbf=)95%W6DWy57cAzS~ q7~}*NP#R}aV6tH5VH9BGVU%I!Vd7HeVB}!sVq|0FU<Bm=F-8EU2(%dh literal 0 HcmV?d00001 diff --git a/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc b/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b50bae1fae55772838d508428f39b9eb9d4ca90b GIT binary patch literal 2237 zcmYe~<>g{vU|={fWp?6yP6md@APzESVPIfzU|?V<Hez64NMT4}%wdRv(2TiEQA~^s zDNHHMEeui2?hGj`DXc9FDXghXSuD-WQLHJ9!3>&gFF~gHB{PC3C}v||U;vrx3^KQX zfq|ifp@yNEv6iuhVFBYphFYc?rW%F?ObZznFsCpsWRzg2Wv*dLVX9$}KoS*asAWlE zu3@fWk!Fx!sAa8TOkv7qDaxr~O<}5GY-N&Us9{ZEOkn}BBtUGis0~96YcPW*YgIg# zf`WpAbAE0?X-R6aLV8hRN@i+ai9&vwLP<udf=_-~VsVv1YDGzEUU6oAo<g}^kzR>H zaeh%rYKlT?QBi(TG1!FQ)Kmq>oOC?}Lxs#dg@VL_)FQBW5y-<YK}0&l*B|U({{R2~ zzb4Bq=A6{LTg=6!xwlx;a`F>PZm}ho<`$RcYO+OfrkBJ+?6}2P5XGEcQhbXiGp{7I zs5mvbBr`wn7Her9h+@x5OioTLN?FNJ#KXYA@XJC!BR@AaxFo;GP`@%KF;BldzbL!7 zATc>rKcpxzuh`eqN8h*7C%?E@KL_ly_?-OWV!eXOTij4SC;-?&@gV?mE-MEk2cr<9 z7$XlO2crUG5g!8sLo!GL8H1t_#4k2uU|>jPh+<4(h+<0NYG+7e1Sbm4C}v2Kh+;`$ zOJQ$eh+<9QNa1W@h+=DJU}1=24`$HhzQqd(>frpelHA0KTP(S$iFwIPAakJ@#Aaq- zU;yD_32<sDVXR?jW&))UMur;38isi06oz1im5hFxEVo#TOAAtqiuf5A7;bUI$7kkc zmc+;3;w(x{DFp|iCetmJ;?$h9A`u1#hLw!BIOF4!a}tY-<Kw|WS;Wi0z)&R2z`!65 z3JtIy8TlBi#88~62QxSs<SM8`89?b86o%k*-NC@XP{X)@VIe~bV;MuSfEPm(V-4d1 z#uTQ7jJ3>A8Rl%JB99uz1x%oTN?}W3TFA11c_BkBNF_@OdkynKrdpOfJ*aAqY^I_z zB-P9*Obb~t)z>gBU`b(L$XLr-!&1YV#uUt;$?2!b^b+Kkm#iS~fr7LM6qJxu3gW&5 zd87zTNHQ=mXtLg7PRY!@#ZsJEm3oW4AhjsBv?TQwS8{${NoIO!erfS7ma@#$a!s}( zQBX*;X66-?mfT`1Ni0fFExE;+SDFh6)+k=EOnh2SVo6DAUKAfh1&j{~b5I%v2mCFb zf};E+n4~l)20+ooz$n1T#mK`b!pO(S15TA(j8#&oF@mfSC00NY3M%8kv2qI<E1+_v zmN}iFmZgNThPj#1g(0@1mbHedhBX_Uu4>t8m{S<DnTi~07*iNh7+RT9n9`Ue8EV<{ z1Yjbdumy>fFlVtWV69<W$XLTJ!LWdBAww-k33C>E4F@P2q_Ffd$1v4$)^gRb)o@8N z)N<FbEZ`_%Ucgzy3DH@?Rm0uPxR7xn(?Ui@hC;Ct<^|j}EDITHS!x)vcv4uinTmGR zu+;M8$$<H6*-S;}YFHNVE@Y_X$>XSDS-=Nk)$jx}XtMhifg%$ets$C>x42SLi<49H zQZn<>Z?WX%=cN{b;$M@$NERGvMTvRosYRe<dyBOszbHB57Dr-fNj|77xWy5cSd^KV zl#_akIWejD77M6Yy2X}QP>`CJa*H`7zXV(;++r;*NleZLXC|=OMWFISlNXXI*uZgL ze2X0t>BYCW;prl^_!b+e3`#A&#h4t$0Z|a29K{aS5}zCemW@x2;svRSPlgtYQCwhd zW?o8aMHD|w7gX{V3rG?@NeM$sxMD~N15RA(pp>No%D0SMj4X^i;BrogNq|v=Q3PD# zaWQc*vN3Woaxrl*Rw<&UG%V(!Bsfq$17UE2QvfA6h6SKvifJKJ3{x#*EvRy1tYNHS zTF6w(T*Dm9P$US-aZE)Z2Ni(|Q%x3d_!Sv|GOHnoFa{CeYyu`gY5Ep>T7FS^Vo{16 z$TOfwWMC3t;$f^3MRhJzl_n!3YB)fpVqRiyYJB`HuK4)e{FKrh5Su4HzOXbg2O<M5 z01@sfG6y9WUT8^~oSKsZD=opU5d%36<RoyZ3*r@XFfcH1F!Hf*@NkHOMK#%qG(d)M wfYX^?UfwOHygWz}V)F&N0jvxu^1w!eOufZn197Dts5&hM`ILi^hn0gL0B6=J`2YX_ literal 0 HcmV?d00001 diff --git a/MyLoss/loss_factory.py b/MyLoss/loss_factory.py index 2394abe..1dffa61 100755 --- a/MyLoss/loss_factory.py +++ b/MyLoss/loss_factory.py @@ -34,8 +34,6 @@ def create_loss(args, w1=1.0, w2=0.5): loss = L.BinaryDiceLoss() elif conf_loss == "dice_log": loss = L.BinaryDiceLogLoss() - elif conf_loss == "dice_log": - loss = L.BinaryDiceLogLoss() elif conf_loss == "bce+lovasz": loss = L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss(), w1, w2) elif conf_loss == "lovasz": @@ -62,6 +60,7 @@ def make_parse(): if __name__ == '__main__': args = make_parse() myloss = create_loss(args) + print(myloss) data = torch.randn(2, 3) label = torch.empty(2, dtype=torch.long).random_(3) loss = myloss(data, label) \ No newline at end of file diff --git a/MyOptimizer/__pycache__/__init__.cpython-39.pyc b/MyOptimizer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..058b9be16d0a0efced0d6e4650eb2c5eb3a2450b GIT binary patch literal 632 zcmYe~<>g{vU|@*cJ14P#k%8ech=Yuo7#J8F7#J9er!X)uq%fo~<}l<kMlmvi*i1Q0 zxy(__U^a6OOD<~^E11oa!<Ne)#SUh(=5XY4Msb4KY&l%H+)><MHhT_FE^ibsn9Y&H zm&+f;4`y@b2;>Sz34+;NIYPO@QNmz0caBJ|Xp|_J&66XRD;_1z$dJMt%%I8ll97Rd zL6h+ot7A%HZUBrCuE}_d(=jD6Eit(yzetnu7MEj6Vn%9lab{v3NXjQaKRYoaH8BOG z!7ni-F;|oE7KdMcS-yJ_Sdh!F42j1Y1Tr0@F~~Q#peVl}#7&d&7E7?Z3&=)+<f7EX zlGOP8f|AVK%&OEPKTU}!R*<0uMW7J81!k8QF*7hQL~$m<TvfyZ;&Ua!9ahB3z`$^e zGY96jA~q18H4kJ`5j#jRM_zteetJ=2N)ZQ$%avCK7vKb|Edr|rMQRjRQEoBFsqrOg zMcg3OEXC<51w}j{CNJ2j@sQA|EaGKgU|7jeB*ee~A%2<aXXNLm>R09@=INK`7iAY0 zBqpcohZH5|75jSn==)as!@Q~=AD@|*SrQ+wS5SG2!zMRBr8Fni4iqQFVhjunJd9W% F69DH+p%VZA literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/adafactor.cpython-39.pyc b/MyOptimizer/__pycache__/adafactor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5594d367f0f462a2976347a368a45da43954538c GIT binary patch literal 5478 zcmYe~<>g{vU|={fWp?5|c?O2ZAPzESVPIfzU|?V<)?r{^NMVR#NMTH2%3+LR1k+4W zOzsRR%qc7_3@I$Bj9JWCEX~YOtSL;v44SM}{R)mLiD`++CHX}P{skqOxtUd|MO<7y znQ0}dDGF&t`MC-iB_##LR{HwsnI##eNqWiox%vf_CHY0k8Tx67nMK8^h5AW3`APb@ ziNz(UMR2ivkZJmfFw^x4D!I7)i!#$Q^Ad9uGEx&$Qj7GH^9w4AGSf3k6p~VN^2_zO zxL$%h=9dg2kufKfQyjs-z>vxi#hAhn#gxJr#hfD6&XC5K!j;0^!WqSq%9_HH!rQ{o z%oxR%%9_HL!VhM%r?RF9qzHo994SI6!YvF@oGIcdA}OLRj8R<e3@i*$+`$Z*61O<v zp^<3F#=zj2Tac5Qo0?Zrj1)|XIqCUDnI##ydRz($3JM_^nZ*j3FolUFnfZANnZ*i8 ziN&cY3i)|f3JGvCtw5#}lw_8rD)^)(7UgB;r7Hv_mZTOdlxLP?C<K=#<z(iiCKf69 zrsn1sRVp~=7ndY}Ezu}WO@&89Vo^nAnO=TTx_)9(vA&^&fr*}hiIK6XCKp((Uw%od zLP<tqi9!j;Y2bJO#i~MPUP)?EUSdv8r9xs#R%vlbu|i2kszMITDMg7TsR}8n1*v%{ zAeZImfuzAkY859Z=A_0KBo-y+rk11@Y3V3v6{Y4Rg1i!6T#{O#rI47HqM%itSd?2@ z5TBWsS)!!?vN%7lSP$Z=kbH&G;#7r1h1|rv(!?AEjnoQ=Yc$auP@J5Rno^pRs!*9< zs!*JfUz(GmP@Gz#kbvd}TerlV;?x9?#b7rkpg0IB3$@&_D7_RMjaFc#plAd8rC32D zvm~`BF)1fi(@FuxRLD<L084|sTC7l#5B3TuXDH+sDWqg3mlP|cq^4zp;}5P-A-yQS zw4fMnK~9l^Mp{mOVu=n|dwyPGj;56Y%-3k1)<{WBODxSPu~P8M&r8*W8<JX3te{a+ zT9A_(4K)i)#cEn96s4w@<|GznRzZ?na(-TMNg^nvrsWqY6c?5z7Nshr7bT`-rskEv z%>(%y6jKOSgF{uJD7CmCH5nA0Ih6_s$7mR)>KYs9D1azUgsYNsG7I8MGKx}*GxBp% zpx)H9Qa}jhrzsTW=a(qtrY7dW9Fw1>kOt2DFjpy*7NmfZ9Kt?BJp+U-DXGbcmGR(2 z2?<C|D~06z)U>qBWRRag$teXCIm!9C1*Ija3Pq)P;223POD#%FPb~&TBU(5hc}>?q z&jMjrQff(}A;KdV)_}qkVh<>1V6zGCDTHrOqp~~|RM5qPeFe1%C1pV56`=AyMhXR~ zd5JkCm6}M7MhYEpydiR=f<{t)J}7}@rYVFJm8R;TC9uq51+b13g``TXQG-Yk2oHb^ zLWB`Y{=&4cBr`Wv7gBnp=9OSt4+*~16qK}>nO9trnwSELDuOu!)ip>(4cskwodB^P zl35_mELO<RLvc-cMruh$YLNoSY~9iV1yDJbnV191VxUqz6%_xaprRcSOyGhD5>9^4 z7aV&1&)Z{*<?}x)KVI@p3b20;sk%PcKc6)d1YYikiZC#QxFvH#WWgjK0|NsO0|Nsy zs4kCVU|^_WN?|BvsO7HVDw3$-E|w@^NMUSdY+|fo2Z<M{q%h>M)v%{9Nix*(lrYvX z)iBqv)UejD)$r7C)Ns~t*KnmU_j1+pmN1nt*Dy3QEnul(tl?eA#K@4s5X_*-;#Z{( z@+!2%Rwzm>EX^!REmlB@X<JaZ7cnw0Flch!V#+DH#hhADe2X)+pg10!qHgh_76-Su zz)2XKac{ALGuAC0NaBVh{#*QDeTcMsix<VhTii&F(qz5G6_%J&n(A6qlwWj<1yuUp zVl6H$NG&R2Vqjpn#Sss#ZsOxL`EIckr{<*HVgo1T;vx<P28JR~!=VV&d@2HkbP*_c zA&x380tI&wKS+@vh!6k~ykH}tg>msp##@~6@yR)f#l`XQD;a)S>u2QWrUsYf7a8hT z<|O9nm**E{7Z)TZr|O3kCFT|Tdiv=5R>GT6D2<;YQ3eJEZBWb#fg+oUgOP)ggGq#u zhmnnu<sTb!l@Pp{pa=6(GRTJ@)ga8uz`(%Hz`)=PG9wF=Iv7$IB^ip6Y8XH%phzZ# zxrQ-?MUp{;A&)7AwU(&_lz!Ms7#A=>#1}G3Gc+^RGN-WBveYo7uuC!&g@MvV3Wp>^ z4Py$YBuJJ8l!inYY8X<uY`~->D5Y@w6@h~X6j15CZ?`V^er$u?Ef!GCpI$Ap&i!Fk zi+vHufspV7h4(GyAm3nc9B8uKV$RLXyTy{5SdwvzrMR%D<Q8*oVudE_EpBjB#e-|k zC@816B(Ws*7DsMo9=K(3izTNhz634_YKIibfT9Xa$b#aFJw83PBtECe927DjcQP<> zFiJ2=FtRXmFbXm9Fjfg7g)KxaO3;ED(4e3O2dx4F149Wz4P!H73S%~FEfXYoQy5Aa zYnee=EDh9fX7+mt(oh6SFPe<Em~&F|G+ByNK(2v#BO|dO^%e)DX`PyKOQ5tkH9jq~ zsJJ9PHy_j%LH7j@*cZ@7lM~1*Y@ldo;sJR;49x>jWyzpy2~q{ZAU3Gz0Q*7=oIn;Z zq%baItYKKdxR4=*Nti(z6q(F^MH&nY4Cx>}3=9wKHJNU)<mDIT-eS!w%}veGWW2>z zlA2eXUxe;LmiVIFVh@nZKz?Uns^UX(RZ(tnGRSs_sSFID)BwU@*O-A^!&t+Z1@4-p zF!wS^gWa@%3FJOdxGrEyVOhwyfO#Q9ElUZ@0#;BmTnGwjR==13|Ns9VqRDiNIVCgq z7JGJTYC%e7?#qAw|NqxyyTt-(V%%aY0wt^XTb!kN#f7D*sa2`BSiwTKSiwD&Tg<to zIhw3RpkTShms(K}pIDY2UtAbplwXd*PtMPY;wXxTL|YU`GL(fL6#Vgt1qDU<6(A!( zHNOuiRKW4YD8tCdSS5)RGN>w15+0~102Nc<KvH97U`S!C1?5)8QpO^o8pdV@Q1Z)S zaABCh9LrYAT*C~?=`IWtm}0-xved9-F{Ut<G89=9s)5*9Otq{vtXa$@ELp52ELm(B zOl1s3Q8la!*cUQ@WjKm5N;p%P!8Hg&7FP|6BttD*4Qm#6Eqe{a0-hS?8uo=uF-)}_ zwVWlqHEhj{DJ<Dc6BvtxQdnx(Qdn&mYB-BFl&~z|t6^QpSi_dWX2Vd!QS_^XWdVN; zdkSL;dn=P9!$QUd0#JDl8wR)>CrGY_J%v4m5yTQlR>x(-kjIB)F1HN>*hU^3hP<#E z&Iyb~S|Hmb8Q>;^ZJxkb6js9qD!t(195xI!Y$;qwDnVuo)v%>-BZ+a_Fx0SQG1YQ` zNYwC}z*y8%!vzj8E|8yD86+7}c+DA@89=4s0>K(Cu+MqGJfRxq8n%T@wcIIuCBh3t zK<XtKYM4{_B^eenNifv%)Uc)q2!lifB^ksSYI$ooCNLFwK-^r*2XSdFe+}OP(HfqG zj0?n4_-c407;1R4#8cQ)gnAi4HFXU?7FkZXEW-kc8onC-g-o@)HM|Q%K_Sn`kf#Ij zajig&0J?6lnF0$W7cxLpgY`+(Ff0(P;a$jB!<-^4$so;;C7r^TBGSvakdcuAY%VvL z#W{hg=o}<2Y6WWq7l_s{r-*{%Tr7nb!Unk?Y=>YiNJLDMA%!KIc>+_>ry8yWVl~Vu zVv-Ep3^m*$3^iO}bs$k*ZU(UVELlvb@lgay*-Q&$K;h2FP^eZYQ3EOz#H*GDq!y*+ z7v&Z!Bq|hV=B4MPf_quHu(legYJ&97L0uj!{XdX!a!!76X;G?zMsi|K4yY4|(ja$K zfXRb9(?zMNWr;bZ;O;u8+nk%9l9~hReJT{CmXsFdK|9m=#l;AF{ffjH7#ON_;H7~= zN`7jwLSBA}LUCz9L4Hw*LUBQ2Q8A*YR;&kZY8A<WsuFf+K~-b~s_<+;gei!y1xa#2 zCF6?=i{u#?7+!+Pe@G<^YUO~-_##mAN0YJ07^Dl-0McYE0@ZKedJEhD(v-c$gH#C@ z-(mq3XSY~V5=#<qv8I$%7NizgfvjT(cRUTvZgC{RSgc^iEzZp1c!)1<@dTCTfg0hE z#u_W6DrYSQRfR>mAU9}&jN+<FEy^#B&&kYAy~PS<-(mu_sXz+UL4+bm6>~v;`7M^* z(wz8PEQu*8@kO9Yq{t2=ZwZoTNy#jWzr~iElbBl&e~T3~loYQi1S#R!p)q%hB`3eQ z_!etAsB3<Uv7iW4&=z?!Ffc@Mf&3W{aT=&uRs<S6i{dVb2l+fct-#o*NDZWp4b*}) zD!Ii5>8uukV*wNaMW7+KB72Z~z)cY_0cwwc<1P?XAyj}WLsn2LgpG-Zk&ls!kqO)) z;b3HA<Y3}rWMdLw;$!4sVqxNAWMPzH)MDae;$mWB6k-Hn5hfufHbyZ<4kiIcK1LoU z2}S{sK1L2EIYv2F9!4G}4kj@sJ|+c5kRA?3Ij}im%sfmYi~=Z?Gbl4?a{3j4>gOU* zV+WE<K=qYh5vWZW;-<;yr^y3qOy?!$rpCwL;);*Y%?Ay`#K+&_iH|QVP0WGHu*b)z z<R{0+gB!HBctCNSnVg$il98WM1dcs$L>q#_1sqx6l*|ilMklA{<iy7#)uvJ)b>L6~ zwJAW2w_*tf1_lmB7AXcshW~s#OpIJyOpIKtOpIL29IW6rzb0D|sCNWTmqp;T!U}2+ j<=*0iH@(5?kgS16&Mgibh`;PW*`*j11T5Sfj2z4W7lZjL literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/adahessian.cpython-39.pyc b/MyOptimizer/__pycache__/adahessian.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc22c127ae8d56a626d590c23499817208854192 GIT binary patch literal 5843 zcmYe~<>g{vU|={fWp-k_Bm=`^5C<8vFfcGUFfcF_t1vJyq%cG=q%fv1<uFDux-+CO zr?9jzq_CtiW-(<kH#0}Eq%Z|DXtGw#R&Y#7^hhl(&P>cx@GmII%+0JyE#l(x$xJIr zO;JcI%Fk8EC@Co@w$j&6&n(F(P0~xw&(%*!EK5l=HPKH@Nz~1N=+;ll$xqVHO)M@+ zEdq(hL&fz9D!KfNGSf5j5_57Y6>>6@Q}c>bQxtqXLv$3J^9w4AGSf3k6pRdv40IG+ z63a4E6oM0TOH*^WUV=R5mkc72F$a`WoWQ`qkjfCnn8Fanl)@OroFdlFkj9w8nZnh= z8O4&yn!=sJ)56fq7{!*to5I(^5XGLtpCZu05XF(oxqxdSLyBOE5LliY$`?-&NfB*f zY-Wt&X=h+zh~f=q(3H5v<(QHP^-JPAW(Ec>1qB5K&)kBX)ZEm(l46CDj8x=6Ow38o zFUl;*$b|%+lB0{Ghih=Kr=y>hf@2;?WkE@1S*k*CYI1&Fih_SpN@@{2;1$yHixhkl zlQS~&QWbns6N~aP^U{^Lz*agIrI&(jv;uQMt|&+>O3W=*(8w%FElNzvN!7GcfH4*F z(-gqcsU@jJ#R?_)3i;63Rmd+=NXbktDON~HP0IvXfKaHAUX)*2Pz<*qr$|8~Ehj&* zL<g)rKQA#y(@G%+>M(_(#FA76jg-{1#L}D+D+L2RLru8Wq|}ncVg(JDcCaa$D8?k` zr>3Q4CTD^}u@sahK)z4T&n+k|$;?YvC@ReZyDqUTwJ0$?wHV~3^rFNRkOqasycAH# z!>t1aM{!|kVo_>}0yMT1N{SMbk=>_Zpl7M0V4!DdX{m{DXlg+*HcytM7Ue1=rlh2% zfC3K`$tkIM`MH^Si6!|(Afe3Mf};GgRE50K+|;7X<is3>;*!Lq%$&@UN~AC`Ox3kO z*j1hiO6>6|smY0z*er#}D?sIaj1&q|^AdAPDm9TzG|)3Z7zp)hd_jJBD%QA4ttiOP zOU)|*1s=!~Pz@MiV~A{OX+cV2Noss*VseIpMrIyr>_Z|b6%_htA)Q!IP?TShnVVRW znV+YSpO;gqkeF7ITBM*=l95@g1q!L8)FO~eic3<FQW+@jG)nRniWAFH6-qL5Q<1%k zl%(?FixWZl4wT~29F>t@u8^CUR|(b*PVNw;3JFyS;7kSbJDU5_uy_HHn4$JQpEVN% zcG|CGN<VuB1m4)EYqyKtJoo**{d21xgGYUpu6EC%WfB9F@xlHjBLf42TQa!BhD!1< zFfgz)FfcHK%4iuz28I;I8m4T<VzCm26s8)c*$i`;ni-oIQy6PlvKfnYN*Gd@Ygmv( znNwIK8CV#Y8L}CRTvAxG8HywFm>3yK7*g13z;Y;R+2Lvnz-n8dYB_3{Q`jXLkkqn4 z&76_Kn$1|e0;-O)h7G2shNFfdp0S3#h9RD*hO>qto;ii9h9RD%gmnSiLWTwGH4IrS z3mIz|;yFq<YnW=7YglSnYuHk_ds%C_O1Mh6YZ#iD7Vy+C)^IIkVq~acSirlG!G$4K zqLw>_v4%UIFNLLsA&b9;TaqD#$A+QsUkXDogC?(E5gP*oLzTE^URh#JW(sPlW2JzQ zNG&MN%*oGFC`-&KO|?=e0)>1LD4?s9VR}J@5mG6aSfY@bmy%kcU|^+C1gcL6t2DGy zsFH;l2dgk4xejV?=Kufy>or+#G36B9Vg*@Pe2Y1?p!gOKBvV5&{ViT-2886;Up!E* z0w_ml3f$rf19`%=s3^ba7AHKt-{MBfk+=90%hKbM^YhB$vr~)mQgd#x6sM-9++r=s zFG|k1#p#}!ms$iWq;7HNCgznU=EQ@gIn&|dETH257He^7L26ME69WUoEsps3%)HE! z`1o5~>8U00;EJR877NI6w^)ly5=&Awc|g|Wq}^hJNEUH2FfbH>l6Mg|hy_aAMZ6%E zAcznE5kg?2q18k2EyjYCjJG)B<CAj|i;LsqS2Fyv(a*@wO${!|FEZ4x%t_4CFV8Q^ zE-pw+PSp=7O3W+v_4LvAt%TQ(pk@WgNtuaxdIgn5QVa|XAY7~g%3w?!jC_nBEX63q z2*P5Fd`xVNY>aHoY>YxoLX1L8JdA9NRl<nsTo2~!WKfa;83Mu}HV8X|vMH!dmBJXz zpvmMHQUq#Otz;~c1NjZa1`8Caf_%xJSsY&suF>Q3K{dJ=NHZvC8JMc1kZeR#hG8*C z3)o^%0D{_YEetgbSqv@=%^*iGS9!oIJonU+Vui$<97M7%R>;glh{Qvxa)t8Dk_-h{ zwTY<NQEUd4NFZDca<v)*149i%tXwT)4Py#}Bq*l})H2mDWib@-)i8xIFftVK6*55r zrAVHEfx%Cc@fL4UYGG++QEG8K$SO_dTTFTew^+eWC{n}{UYs_d<XBNqWM|L7!0;Iq z?o}E{;e<$qdNw)v$%#3|c6x9vnvA!2z^M@G$0CrEA&~)&7Ep{7Ni#4oSb_Wl3Q`8v zDp9PaqXYv;B`Ctd5ddmQ*Dx$#02ft8wTvZ<HH=wICCtrC#T=jluZB^QA%#&A6g{kB z3@J?6Ohp_Z6`&%Kk)e<yg*li(lcmZU9$`_bMft@F`K2Y`Y9}!{ximL1ClS=jhNfvq zom{L3*HOd=3OP_(^wVSqd*~KdW^ra-aY<rca%zzbDB-b!8&*ZKAeEq;cZ;<uwJ1Mc zlL?Yc6+z<QKmik=kh#SJRsprt5fntAf|h}mhf#&GN*pPekkz3C6e#b5FgVSKFfuTt zGt@GcfcqOQ3>^#$7#A`$Gr2It+SD?2Fk~^+FiA3`FiSFYFk~@{F{H3$GZop?FfL%J zVOq$@2oi;=U}cbGNMSW+U}mskC=@GUtzm3t1glO-VG9PSW6NT%VFcBJS?t-&MH5Om z7I4D&3mF%1bueUcFJ$ash-XV-PvK}`s9^-<n_A`)o(_gA-eyJ@hS;=P7HB7dqlE)1 z&sxHf#h1lj!(78y!<xcX!va#B!YaX#%>b^$*h&Oy*lIwoVRm7NRf=J*Wv}G`<p>sr zEWsM~8V(VV+jylJ%o&&%;spy|6|MleSCglz79N$(`Jh&9Dx|mIQCb3O;TPxUDG<>l z0kw3&sihQ@T8b4iOTY~paO13`C^0h+)bdA^0*K^;5@n!N1WMn*p!fzA+n{14g`pHw ztT6hmWW2=!&bW+Nb1R2UPGW9SN}^pLC^_kXs%7NNTAW$}@9iOSD^v?25rb=5KNMGi zTm(u)U{`^vi*!)Y!v!jOvKS^Z6|w{~K=QQ_0|UcKrXo#{{n*kpmrYJ)aS5ntY8L}i z19qbpQd-90MwliHSAmkZFet1H7(mIph9Oo1l*F?bQ$Wo&#sy3ZL6Oc}!<faA!qm%} z!Ysm&!Xg26vnDIJ5%Ln`Jx!(}aZpBOOGz!uOioP)6?LGFI=F?R$yy`=l4LJR%uC74 zE4jr|4C;S_G!$8Y{9y_r5XBoP{}zGj7Es@}2;|8SP3BuXps0(_FD)r3Eh)am4QhDD zXXb%9JVmJ`iJ5tzQn?^QQ|K03T4qj3Y7w|FK`J!Z5(^4a^HOdx=cML=oeQxY<hCMF zkQ2l}1SmXiu@)ufrKc8w>Ze;AiKQj^Aj^x4LB3$F$}9kvN}8OI(g##KfD;s`Do9Kz zzQve%i>azOiX|hysyK={BfjbuV-+NV!F~jJ?-n;wRC$47KnWCNY%Gjoj1r7oj1o*d zj4Vt-j8aT|Ongj2jBJb||Jj(t7^}pv$F^TGC>B6QfG|G;0|ThR2QK0_fZO1;Of`%t zjHQf4LN$!d47E&<LZ1=TPGwxk0HU*)To@)W#x~S4*D$9rfojAFjIjl^EHx}y%qfhe z48>+OELkidR#8x)Sq(!LYf(=P3#cMUVa{eM>M3DgzyVURkTHuhg(ZbGohgMag}s+) z0T)Dc0rx_N8kTsL8kPk-DWLW_rzArSa|)Lvg9w8J1E`FwVaVdGVUc90Wi8>WVQpqi z;mBqxdXU0h!wPCKm#{D3uK~5dSyOmz7?5~;HVieaDZG*(l24L>iJ_LQmc52Kg<p~( zhN+gLmJ_5OT<UT_O5GZ^1w1tz3mLNnz<%SX;gDcR5$FZA+SwQI)Nn3j%n}5-8H7_f zYB(hrQUrUM85vSIYB)e<NiqmCNHElLf#i_*+%?Q8LXr%%JT>e!+|mpSglc$FSfm*i zGA<CPVNMa2WLU^p%Ui>o0`fPPG(#<44NI1A4O<Ow4IekuhC;JKjT+FHl}Ococ&QqY zT9lSwlv}Kjs8F1lm!6Xf9;L|5tV#s;KtXM4L~)2^`~=kTOwP$KE-gw`&`3_q$pICm zNMjTVy1EL93NRHV8HptdMX9Nv_I_dsv@FlfPf5)Ik46-wmXsFdK}z_X{NiGyF1ufT zUMix`?gwr)y#%#1ip&`p7+!)(zz|KQTg)k$xwqJ}Q&S64GIPPDC8U)NZd-#2@LMb( zFWq8Kttg03EK9#73}Qel^7xd@#Ps;$LU5a0lkpa7VopIuB8UNUZz`m14r;1viWC`w za<d1B&;=2o%nQ!9MW7<M$QvXE%KEpM6O)Q>vE-&E=H22<ttd#$ONmb`zQvlHlb@G* zizT--CmvFUfZ9s8xIoQ?;`p4*?9^K<i76@Zw>T10Qj$TcZm|^PmqYm}nPu^s{NVE8 z7CSU(Z?S*^tVj!FF}QRGH)s4nTz^o(!V2n68A5thMz;i#G82pAlk<y;Qj<aPY=|LZ zbc-?j7AttX<`yR?$l}4pD!2&*CO`#EkuL)SLmH@n0W}jDI9M1t7}=OO7&(}D82K1Q z7@3%a7&Vxsn7EjE7&)MY4G*ILvlOEolK`Ut69=OVGY_K}Qx!jQ!Bdi2pvh4LiqRr) z1haxX0cu}?%X>(qf(jrvO-4UWK2XOiFEKY2)bWas&&^LM%>l7_;^PZT6LX+4?D6p_ z`N{F|;116%j)J26g4Cjt$|7};(V%8-5jZ}<F~**k4=U?Ziomf8^5HFBXpbv7H76%N z9#YzZ0|8X7f_pFEJ|HJ3BqTt^BO{9d10%zKJ{cxPE@lpP4o)^kE@?JKuqsW~q70C= oV9Qv+LyEUJ;T<%v5+obI)_}tL7KaVQb9SI)Tnq{s7FG^s0N1(vlmGw# literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/adamp.cpython-39.pyc b/MyOptimizer/__pycache__/adamp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ba6bf0c26ff7652d56f9139e6b93a9ef69f58aa GIT binary patch literal 3204 zcmYe~<>g{vU|={fWp-jF9|OZ<5C<8vFfcGUFfcF_yD%^?q%cG=q%fv1<uFDuf@!8) z<|t-HkQj3gOD<~^Yc5+98<@|O!yd)%&XB^A!rH=+!kWt3%pAp$!W7J)$(Cr$$iTqm zn39+qpx|Fnl9`)Xm0G0WnOl&Pnwy$el30?NpQn(VUyzxaqL5aUpR15jQc_TCrLUi! zS&~tjq?eqZtDl^cUzV7dsSh?qKPe|aNk2ESxFoemKQSdSw*W%w6;yI@1tb=v7Fj7I z1n1<JXXd3Vxa61TDU@WSDukzIre~BW_~jSnDtP847o{c^rz<Dfs4tTv3{<o0M3b znu2g)ae{&_+$o7g6`5sv`9<maiAlx!Mg|6EdIlCohUQ$(`6;PZaAmk$%*ExLUr<>D zc7R5*rh<`yk%4YOQEG8&UWtOAW0-4@f^&XRfgYEyXNW>hW^!s?aVpnKP`vnQGTq{Y zhuSTUqSV6D%%aqkWRN5>W`}Z$Js21mQW>HcQy8L{QbgJr(il_NQ#e{UqnJ}TQ@C0f zqF7S6Q+Qe!qF7USQ}|jKqS#XSQv_NVqS#YJQ-o54TNtA_+8J0FqBw&YG{tVQf`k0| zpOqgk`6dO}uVhL;dj<sF*r#i^i`_i;{k{Ejs~&?#eU+|uMT`s#49{oH1c9CQZpoZb zH-RWF1_lNYpBa?6bQl;IN*HRGYM5(SYFKO7YS?QyQW$&LYdK38OPFdHnwb_b*D%&_ zE@Wb4NMQ(O&}8z{WWU9fQ*?_pDYYcA_!e_&LGdk~a&VN!r=%t)R^DPwNzExqyv0$T z5?=&LQ@1$sQb8#szf6<m7E4NIa>*^$;?jcDq9RZlzr_(BpP83g5+AS0b&I7qH7D&B zTR~z`Vs3E}Cj$dR5y-hk+#nVYh~Nbgd?13Kfq~%`M@nj1VrfoE@k+*9obmC=If=!^ z@$oAeewpiM<maXam*f{2>R09@=INK`7iAY0BqpcohZH5|75jSn==)as!*T&U-4+Qk zFfa&$f`$VWSWH}uZ2#Gqs`$WxuLpB#G80G=ia~5v1_lOakS>t+8pZ_-HH-@wQy3RA zrZ6o8<vnIUaFo3K|NsAgO{QBcWtpkvw^)iZt5R1o6$yiU!dOux0rD1@kOYbF#wTYa z=H;d4#Di3egG>im#lTo440R2%Jcg4%feUsLDBM6!0u>xk$Gilkk0L1s1_n*WB59Dd zvLFI%H<$oBfIB`Xu`;y?Y&}SAG1PiNsPzchWRU-%!NC9uI8fpOTOPu|z)-_d!`RGN z%UHuw!_>@J%T&Xd1uD6idKnim)-W$*tYxiXLK9)DVXR>iVaQ?vIl7mzmc5330rNtJ z8dhnB8a8QALS*(UQea?USjl*cIVCd}l(O>ji*j$VB&MXq-(oH<&Ar8(m{hFE0S?_F zIZ%KyR)Q1KElyB)#;29$CEsGJhzDz9tAx<ZDfuPn@gNkRoL`)oml|K3nVXrDSd>{( z2@0)ZP`ENMaxii*2{8&W3NTekK*I$?c`~d_00j|98z_jt!E*>yxHHzWr!b{6rZA>6 zmN3*XW-&H1rZ7q{h%ksV)N<4?WHHq+WHF~PxiG}`)pFJ_EMTbt1tBQ6axP@4<*MOY zz_yT~hNXmk0Y?pI4PzE-GvfkIP~uz2D9w<~P^1MCgUMuZr7)&2_cAWvuHjh7xPYgI z4Wz1;yN0=jyN0udyM|eWp_ws@7g?1Eg9Sq^a}8@Pdku39dm0lH10zE)L!lOg7tEl^ z0*MT8iYih9B`dyL%sHufw^)laK-GXeD0L}<2whM_u%slGB;I1qO{}=ZlABnPaf_w6 zu&AU+2Bbg*M1Yf|Du|^9GK>ctiMJRFZn2~nC8peBFGwvaDJ@DWLX-xe65<w9c}f&_ zYDGa}UP?SDZxv}WFfc@MA>t#7xgb6dJ$||43ySiyQj<Z|pE@X)iGZSsk%du;QGiK^ zNr_Q~QHha-iI1rYmhchk{ZO(UDD#6H0*=}XjG%fOl-g4mN*RlUY8aatYC%!S;KDG0 zF?K^Oa}9F}qa=j4pq8bEC5thIp_HM>qEHRQ&SI)%tzk}Kl4OWss$~NOeG0QABtbCO zut+l0a+a{va5gihuw*k8O-f;@;Y?w*VJKl;z*fVO#k7#IhBJi?jn8hw01}l1k?fKT zOboSLwcH?mtPGM2DIDeu%nUXRDI7IyATdb>VFn3?TAmb+8g?W;Zw=Q1c2J^OzyVSz z!BE2rsvKqHi`F9an#jv985XbtN^W=4h@?gd;68A`YpaMkdF650Ze8s-#ENrr`t zwR|<qDO{2aH9XP`wfr?46BvuUYS?PHM8GZpnE=WFHGI+xwE}q@HCzx@3P-JA2`H?1 zni*3#vRNiD7IoDyEZ_xKUjjAEDcq6_HOwhIU=m~xTMSdJK&@a6a|*8{LkbI6#|}^k zEnrUpxdr3`u$%}(ji3mF1j9l`Mh37vM-4wWLyZ8)cZF(&5;dT5oX@WalsX`_5~v0O z7wflJic3-pZn39U6vQW%rQhO$GUAI1!G#pK*nw1Mh!hJd2Q>L^@qjD4`1GRu(t_e5 zQ&5^W0Ob+Z;*!LY)FM4FyClCTIpY>rRccXwaePi@cIqvb+|r!*B2X5&#gUkjk__e= zgLE5#2z?M?0m?f<ki5g5oReQ%T9kT=B`3eQ_!cYJz9JKlGLW5+>f8`Q85MzjW($(E z0uk0AM+hWkCKksh=NA=0vWnp?bP*#^)t{H2dy5kk9FQDn2GS3%ufg>Sn2-dCSu!v% zn1gbg6EwGRFtRc6F!C{SF)}f6F!C^RFmW;SFmf;oFbXgWF$sWZCNV|{CK*N^CN@SX zCJts6Mgc|@W<Ev*W)Y?;NV`@KWV0rF5vV>YVguO&Dtf_@R0Jw?!GxP8qn{=hsNs>9 zn420Oe~T+VJ~uz5GzY}yiH|QVP0WGFfLjAasHImCC|HX?&cDSAZE+;0=H$f3Bb6&+ zAe+J71l7U<Aa8;Q4n_t>hW~tQ9LyZd99$d%;0BN;=Pma5_>}zQ`1m3_P(X8lV_PpT z?-o;DUX&<=mk(;G=z&R4tGx(R6&HbO*&=X!gRDWaA0EuNIBXz3vjf%m#h}I%3l}JK I1sFM)04Yrg?*IS* literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/adamw.cpython-39.pyc b/MyOptimizer/__pycache__/adamw.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae0384996e05f7a7498980d9d36881eb3625d2e GIT binary patch literal 3827 zcmYe~<>g{vU|={fWp-k!FayJ55C<8vFfcGUFfcF_n=mjiq%cG=q%fv1<uFDuf@!8G zrWD2$<{ai+mM9iR26u)OmK4?&h7{IRmS*NC))b~-22HjqLj}i_#N2QN|ALat+{~)f zA}-I|f*gh9{DREX6os^+{9J{A%8>k`<P3$}#Nv|FBCeMpC-`YH-r|JoNd^faV^%1q z*n)w9A(bJDF@*u-nkeQJzIKK*#uWAxjuy@+7BHJLg{y@liZz8lg(ro#g)xe)oq>fR zianS?Q{WaW*l~%6Sr`~V&PvTq%_}K}x;QZ>J-;ZkBqLXki%UU4K_Mg~RUyA9Gd(je zF$b(f0j@%!Jh51zpeVl}zc@8TAu~@Q0c5U~f}?_OYDq?Zib7g`kwS1uesV@)aY<$} zEHDyFGV}8i;`QLhgLH)}lqD8rCgzo38tRgooL^dylbWIso|>7SQKI0Inw(gv5R{r; znv+<RSp_!_Y@lONdMVgHR$wm33k8WqiMhoJ8kr@jMTto{shU;_Fs4F&ngUolwIsEu zSfM0eAs-rJ3i(9}DVfP7#R@5@X_<MMdFgQT;AVm}DWn(WmlhPmC3A`tG}3bN6H9c! zM&;)v=4e_e<fJARfy_`SN-RlL&`3#5ODxSPu~INh)iu_H>rF~6Ni0^-2m$#v8fplb zibXLjIX^WmEi*Y29BZYZum=Tna(-?>X$iVJi%Rpr_9m947A2;q7K1#TUX+*u(yWk} zm!gnaQmjy1SejUb>?{ofJxd)013gPiOHG9HQVWW)`K=_iC|4mdB_%Zl6n-TcsR}8n zdHK1Sd5I<YMIfQf+=8O~vQ)S?LE)EInwwgbnVguTP+XFjl$n!RQi&QY76^OG!O<R{ z0*-bpfeVpWfXE{S4yu7h2m=#yi$P(applfHpOc!HhZ4Hw8L1^1sbGhf7N;tJ{Nv~w z>|T_Z0#Bm(X$mD7nTR+Cc?gj!z_}Zwtst==wMZesKM!n(bADb~YEgP>UUDiZ5<rO_ z905+LmHBxo3GpbwgGeE6i8;lonvlGxr>77P(icNq9z*1SjFOUqVk>?9#G;DKGQIqw zbp6DnVtqpsLnA#iOA7<A)gZ%(%4!6SH8(WWGcYwYFoc^+R1jdAn_rNcSCm?onOd%w zms+BqmS0qwYoD27TT~fXYM!PW6a}{6`Ja^^FZm_~*so+tKYIoQ-q@#Ww~O68_x-*7 zbE_VMM}3v9cF(It*112dYO!}qW``D|Ac}*5fdRy41{L^gKm|Ty4O2E_u~-R13ey6{ z8m5Je&5TWqDU3BN*^I?HB@8Ld3m9uy;G)baERqZ?49pDKj72Uftl1335qV6E3?&RH zYzr7`z_Lhc+2Lvnz-n8dYB|7a*(DhuYD<`Em};17SZY{n*itxqS!>x#m`hk{7@C<D zu+}ivurFj{WJqBMX3*sFO9y!ZgsUVy^U4x)GE-2?8Y_kBT7*<;L2+hIex5>EVoqr) zNUDgDfq?;P#s~W<RhSu|k_b^SDkPRDWag!$Rwx*NG`$3ABCOUBq*jyl7E?~qEmn|$ z#kZJK3yN>?K=LajKi^`9W=2i6TU=ovXSo&?<rm#zOUp0HO)R;^0xE-Vu@;vWq!tx{ zBIy=Ke0*kJW=VX!CdVz7;?$h9TWsJGyts&yfq|h2RJRv#gIGKuf|r4T;T8w9ye(eI zc#AVWJ~=0`xHvw3CBrXs{fzwF)Zmi*B18SkoWwl+^8BLg;)2BFRQ-^m#Jpl(Pal2X zN_d5=pO})ETdr47StP{3z@P$3G7_M$XX0SwV+3I-Mj=KJ=3r#`&&E{62g<DBdN2nk zGl3LCF~~2V`W!?T>o71dlrWSqf}@zJhH)VyBSQ^C7PAXOtUxVO4buV^P}>6(?S*3C zh-daI0{P5MleI_?<V}z%w|L^?i&IOAOA<>`<Ku7ffI}r7Qa#<`Do!ndhJq$@kuWG! zSivf9v4T|;iGhp(`KCw`#FAoQU@!uCkPYNsMgc}H#wsD0cM!Jtq4){p6i|k825CIR z$iR@oSj$wyn8H}fSR_=#*vwGNRKt+P;KDG0F?LZca}9F}lO%*Usg|XNB@3jIp~#|8 zt%fCwX#sNzNCk@|Lk)8Zt0aR6g9HOOTv)OgvzTgGYgn^bin2;rQ`kTmw1zo_U6P@e zt%f0st%gOCp_aXby@tJ+F@+<Wspv=wM-6)lrwv01#{y1}*$WwK*i*RB_}n%OHEcyk zP}K0)Fo0A_f=F&j1}27Dj#|zdwxUNh>?u5w47FS}%qhH*3^7c#+_gL)J6IVc8B&<d z8JHPt7*d#PxIto)48jZ&47I!|%r#&&P(EJ`#{%XW?uCpCxIii;7;0ESF<isBfVqZe zA>#sWkb0144eLT?Mur-;35-PzB^+5iHC#2EHC!oty&zGp1-uIxO86G=*6=N4lxA1} zauGkMAgSfgld0jvEek1}YIvm?Y6U=HvVgxvfSaL)qlO=3TcKK^L=C7dD^Rs9AhigT z8;cbZ6^b+S(sNS5Et_0u%S)lSB((tASOissSQ=m;;pCkB;!;o}BRMfA2h<8fsdOC` zVDcpyi6sg}si~kMDzPNB7+iPe=BK3QfD3z&Q7{Z@z7(aFlosVdRORFs7eiW=AQgT^ zBA|o}N{^6|0aPATd4bZfLP~yWu|i&ci9&H{K|y{|i9&HfVo`Cb0;p90Y8e&lC=}$R zCKjhEB<JT9XQreUDFlO6fOTi)6_=zYrhs!GxB$Jy0t$m$?5Pz6@rh;Wx458;_~OD_ ze7T7g@d!S+e9~mR#hhPS0xqUB1&X9Wxse5A_brx`#FE5YEV-pQ@wYfLi{l~gy2TSz znpcvUn+h%)i)2CSnR8O}Zn2htx)8Uxs#1&ci{o=Lvr})eB&MXq-{MG2Nl6Bq!kn8} zaf_w6u&Cq~OKxIG21F_)vn*ay5K<1XLxcVn3n&1K<Upn~78HSsH;5NO_7y3B0#gx0 zD1j8Pf(mp)NV#rwOCTvTu{b_Czo;lR8PvKr#1JvM#R_f}-r@uWReUk1PXsP)zy!F+ zk!N6FFb9<!ZlEHEorRHuk&TImk&ls!k%?J`nS+syk%N(giHn(sk%LKqNq~`sNq|X+ z5u}cXNq|v+QGi*DQG$_;Nq|X?Nrs7sNsOtA3lw4DdLTPBS&KkDyCP7bTLdcAia_<g zn<k^5CI_fe$xF-y6`Ne~@wxdar8yurPkek~X<`mkhCMz$B|kYn9$ferfeOeXkdtrm zLTi-d)SR67_#&`}zy!ER1$zxtJadA)21<h*j4V<Nj12$z_;^^sRhlMSksT;t4M2nz z0|P^pC^%mAz<m)22@0<wQ2hbcfn*8TCXgk!IBXyew*ytr#T*O_3@pqXj2w(S%mA-m BB7*<` literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/lookahead.cpython-39.pyc b/MyOptimizer/__pycache__/lookahead.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ae677bc3ab72ad3f7a4591ee097c38fb2755106 GIT binary patch literal 3134 zcmYe~<>g{vU|_K1NJui`XJB{?;vi!d1_lNP1_p*=3kC*;6ox2<6vh;$9L6Zd6vh<h z9Hw06C}u{G7)uUIE^8DkBZE6b3Tp~m3quN9DswY)6k7^YFoPz0)h`8~{QT_1jMT&w z1^<GQ%-qbX)FOrOqQruN)FM4D&)kBX)ZEm(lEjkC{5*x+{FKbJ%+wTxw4(f6D}{`b zl7eC@ef{*zl8n+Mz2y8{{luJtjKq{2{T!I7dIgmw`9;YYTmgv%sYO-_30Um2Qpi>) zE=esYR!GY)Do-p*(NQo23n(NdCTAxo=)!GEEUL&X)5|YP*H26;);F{?FxNA%Ff%mZ z;_^sL&Q48HD9KMxEy+kNQb?*)&{uHIFQ_caOwTA$FfuSQPzcH|E>;K!@pBXNxL$(7 z(@&G}7AM@}nvA!&Q&Q6sOLIz6GLuV^K`M|jGn7;8!N9<f$`Hkv!VtxjB9OwI!qUPR z#S97ID3%oV6pj{#DAp9g6s{ER7RD&H6rL2`7KSMHc7`;@6uuPx7LF*6b_N!PD9&I8 zO`%(y@IZMEvEhOJOEv}uhGbAQLqr%DI2afhK!VPoXkP+~cE%be76xX9Y{nvy6sBy( zVwF55MurlG6y_SH*$i`;ni-oIQ&?)4vl)x6N*GdDYnWl8C5$ypHOwh&DeS$>wJbG^ zH4O1gH4IryS<E#I@ys=hSu8aS@hmkASu6`!YgiUCGBT8~l`uAg)UnsFEZ|tk;KC5A z5W`f<TFX|$kj3o65Gzp2Uc<hCvxc>XZ6OmQLt#&0T?#`mgC>VxI?Mx5#s~W<DbKvJ z#GK3&h2ot2a)r`@l*E!$g`&igR4av-j0_A6Rgy3T@N@)D55-mrMId=i=3D%5nRsxT zyTy$V$iBsg;D8dACf6;lu*96wRM(=S{GwYdpvb+&my}qX8lMkK^tX5l5{nXZ<I{`s zOACr`aX`~S@hvup`-(u3e~Y!aB(Ws*7HeinYHsl@uHw`Zs5(uKTP($?IcY__3=9mn zSRrm=%r4>waar;bb5n1zLzUiQ1>3Qbp-75>f#H{#enx(7s(xioVxE3^eo=ODL1J>M zen?SbUa_yIkG^juJb9wzs#_fK@tJv<CGqhf8;e0kF))fT@-cES@-gx;vN3Woaxn5R zaxn=oaxqp3!BdzXOj|M&EMgcK7+9egoVaQj7#M08Qy3*d(VSPyRKt|Tkit;PP$W^v zQNxhMSi>aAP|KXcRKr}xSmag0w1ByWDT`$xV-0f(lMMrilw?@I3K3(3@@rXYShCna zVl^xhAkXB$n!nHU*BVz6XZ$N^3?tbVuHz_C#r#SKa#@kym=X{kj;pe&%tSR~58 zz@W)~izU4%F{MZl6u_+DGUgUnYHmSEWqeL%cIqvbl*E$6Tdc|X1(or)SQ1lG;x$=| zL>L$tiX<5r7^2wH5{paX3vO|P?TLp(D+|c7B2Z952x*Yn+>rQ>2g%8T@*c<#22jv4 zGJykGfKh;vgGqv^N)!>uNP3b%DHaygpfCqzYH&D%EUID10);YPEn^MC0>&CfP-ZJ+ z1BU{mpC(fgAIQaOAP+MkTn}*_NP7|3iG0PCdCBqcG_MVECCC;AP}EBxT!pI44<!YF z)Pqt1*kxXz6u^)Lbx|$2aByLWHL7K<VNPLWWsqb@VKQf6W&p=Fqa-MfnI##-7*bfW znTmvJ7#1+sFoS&}R0FDKSpAAX)mafc$TOTE0u&~iERf8?0?MLD9%TW0^cH(^PJVG| zQR*$0oc!WqaIB&EyGVn9fk7YSSsqaMGVw5SF$yq>F!C@~@gaG-B((q~4nV#H`4;R` zE4WV=GSq@<Po@?IP*J#m8C1wMGuATIFs3kr(t8bK3X3E|ElUk^4O0zE3hNx^U<OS# zKNO=uWdz9hV34D<7#J8b8EP0}m13A`8EctJ7-|?x7_*p~nTixjm}?lD8S|KG7;Bg& zGZnG~Lka^;mRnqz#hH1<C5d^-sYUXj^v4#Gnpd1(bc-o7MU(j!lb%757)Y403@tQn zaoK>%wdDNV0y~heia~y4V`BSXrGXq6i6yD=Aa!~+Ir+(nImLE*2yNiZ1da(!W-$8} z7bw+3atA~l8fRSK3<EX@tguK3ltdvJ>=r*bZy~8L1m!p`gecf3H0Ku?F)%QIS~JCv z{Kdt{#>DZDjfI1egNcceiIM4Vl?XO}pd?~Y0RgHbLGe($2o$y8VyuQSg-Mbjg?SDW zI4OgYH4`X#rm)m9r?A#AmN6DNl`y5SHG^t8<`R}>#u~;HRvU&C_F85TOA=II!AxK& zVJ%^+VQ6MrzzC@uL0LG9y@ny4y@t7HUJXMQ#{y1J;z6n*+2J)L7rch-Dy#u#a}Ga9 z2!e9DCKH0~Rs;&RDjT2t#FWgubOl6sf}39o@S01ZJhLPtzqACYDVJDKkdv93q8Fmc z4asBtIr)hxkaz?~+AY?CqRhM!aHUoR%00JOi%Sbqi;6&r1`<7>nhPAQMZ%!C<pAeO za4Z!WgW{79ku>7-Qp=<Gk@%nlRwNF}=n^0TRN}5=yu}$GpPZ9eTpS+{iF;7lRb;}z zz~BbTE-K(W%*M#}kBx<gg@us=Tz|1JK}0~c7z>jSs6dlMBm*>GX)+aoa$gZB31~9< zX>x%oNl@JzAAgH0K0Y@;r8FlsKK>R@e0*VPVh%*6NCxCYP!U_C4q|}{6^J)MPP@em zb#!uSP7ciB;F8xIq!{c4P?&+juoy&eFmefS@NuwkadPm1LsOIe7JGbrN`7*DJUBx~ ziGnLDJ#hR%NKm7@2vjxP;!e)b$w^HHwYG~PkpnUu5?o+g!34<GTO2kJ|JZ@b(qd4` LXJG@?_&m%2k?Qas literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/nadam.cpython-39.pyc b/MyOptimizer/__pycache__/nadam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bcbfa6a76f3d6e00685eef8c40c0840e26ae884 GIT binary patch literal 3030 zcmYe~<>g{vU|={fWp?5^J_d%zAPzESVPIfzU|?V<)?i>@NMT4}%wdQE(M&mvxlB<^ zj12A!Da<J>Eet6vsZ7nxQOqff!3>(LFF|JbX)@m8^e-sM%+0JyElLInA!Al3r&x!9 zfgzP4iZO*DiYbLLiaCX+ogs}eg*}C%g)@pJg*SyOg}a3@inX1Ag&~S9m_d{87OP)k zN@8x}KV}97&)kBX)ZEm(l41q0ph99!dVW!6Nk*=MMxsJlVo_#dUWr0}nu23WVy;3` zVsUDULVlitUutnlYEgcfLT)}tb7`)o9v7E_f`Wo)i9$wVu|iU6YMw$tQGP)_NJD0x zLV{;ra(+=keo<mcW?ni%zb{n3LS|k`KG>3k_;{#Ujz#IEAk&Mjz(#|NFGws(%q>>X z$Sg@MN=(X0)wEK8F+q+5OQ)8k78NU$<SXPu15Y8pNFgONxujSjB{eNGFEcM4ZXVoB zkS2xnqWsc=Vz^{Zk%C5APJUvE4%n#tyu=($D}|iY#3GOx3Pp(}sR|k?scDI&IVDyK zMya~SnsB{IsU?ZU3K}7$1v#nFP(#2}EQ(>t`Kf7XnaP>p2q*=GRa$<LLUMj?K`AJl z;KqU6SyY+_wl}dXwJ0$?wHV~#^rFNRkY<I%ycC7Zl46D8!qUW|RJgv>f?{lbC`m2K zRY*)pNlgKTQ%OduLP~00er{%7Vo81xNGLP6peVl#VGhXtywcp%qRiyP9EIYN#H7re z%#unZ?;ED-S|Ge%o|>7SQ4*h$nw(gP%>fX31*p7_kwQUgUSdv3r6!Vz1_<MdlQU9N zN^??i8wg8>3NWQGv*Gb$0*)VW$i&AhWR#Q?Sn2C07aJK_>J^tH=B4EqrRb%ml<F50 z<!2cg7@F!En3}{Fr55BDmFN|uq@ignFE7_iF4il_FUrp=$p>q!ODoDOsnSnODlRD2 zhq(u=6|5H&g#JaD>6v+nIXRUIC5hRoc?xMo`MFkL#}r%X>!)XyWRxc9CFker7gUzy z7bRyP=+c}VeM1ui3%FT+{vobb3Lc5Y3I+Kksd*)ti8%_H#l@wm#R^HKB?>9|sl^KA z`9;|Z<*7M2pgdcgpPLHGlSw(Lxy5>5bDsZM`SFr(QiA<Tru4ICK;Vsix^}zR&2!)1 z+dsGJF?iHh>1tQR$iTn=Ra#)5%nmJiKokc90|SW93@hYH7;2bmm}^*SSZml)7<*Z3 z*-IEpm}(fBnHDhDFxIdyWMX7UVF+f>Wb)Hwy~UJMbc+?7Rc<k-78Kv&fusvaG~eQb z#U+HV$#RPYl#Oq(7MB*J78QZ)yu}e8pP83g5+AS0af_umH7D&B8@LcCE&@5Nh?9YV zp$Jq)7jc7FJPZsBw>Y3lqIf0aEzbD(<ebFf;`sQL48P3vGxBp&gG=&@4D~B>67%#y z!B<?6n4GE~Qk0li?Ca^H?^_8klJ)aIrKny(Wf4CE1A`1G1VDB&FmW)l{AXjT;sY0` zdN2q2B{P9S42ne=7#KjRok7~(FfuTtFxE2FFs3k;G8PHdFg7#PGSx6-F}N^HV2nLe z%Ur{p!Xydd?Wtv{VaWoiWGJ#IR0FZIm}*&T7_yjaSR@&0*-BVy*qRwrn6sIRCZ#af zu%)orFr<K@gUyDahIIk!LWTuwHLMF67qBm6s9{TC$12KU1CnEtWT<7YVNT(cWT@o; zsg(qYa!4{TG1PL_a@8=Wa7i-6Fx7I`^3*V=a7!}O@`CJRWsqb@;W1}mW&oQ4aur_< za|*8{Lke>?^905skUKaaA~h^oObbEo=T2cw;Y;DK;j7^gVGw2zVUS=DVW{P=fs3c` zh%=zb3DmIH@QX0i3PR`_0TG5;p&G#&wiJ+0ITvu&a4%$Bz?H&U!!5y3!wO1zHCzig zYj_qiF5pgKt>FQQ*03&QW@M-Vsp6^O6K0TLkY=bAt`T0qvyfo{R}Jq%###}O8yE1_ zFsBGeGH`>!k{{%@6xJF+km?%N8WFG=ELltlRUp@-u+|8HRd7O76si?U)G!4zXbM(s z3rH<W%P-0;R!CGR&df{CNd?z2xtUdoC7Jno3dJR<1<+~`RB&La3PHljIr+t<MX3rJ z$%#2Rp!yJ{a&c6E$(Lj#mM9dZrj{k<lqQy>7At^?{M`JM)Esaz05S@OK@~|+YDsBP z9z<16elfJL0IBdR5(MQVP=RoZ1?0wbNP7NYe~T+O9+qEkv8PrP#3z=e-{OKY;)@H5 zK$#DbAzy;b0Ov}m5fALscgVHA4&3q19-MbI`EK!mvvNG7TE4{ss_bsDq$HLk-eN5- zNi0dd#hjCxcZ)eMwfq)WQEG8!Rcd@<ar`aTs??(V_*=}0DJi#Da!YgKZ?Pn%q{QFi zNK8pd2J=~q3yVsi0x6kg@tUGV0-&OS9U5S_SU}#s#R|6h7Gpt?EXddLAVLB}fW0CK zVo8Ctfr~IhNJ(Y{E^clymP84_>VWv<#N>?B_>w3gbdh+&C_Y56#pk7#NAaWZD@vjS z5J4DUP?VXQdW#iYSr>td5C|a#a)~$t1A`K%H1PvP2|EiT2O}F34<jEV7b6oR2O}3V z4<iR72csCH7&C|#Vg&I87&(~u7=@Ua7=`|`FflQT{bym~0Ld_mF-kBBF!C{}F{&|D zaY2eQP=sr;7J*u|MWFU)5gW*2a72PC0Zm3fO?FU;oR^rJ8Xtd)D?UCqA6#w6-{Ogn zFDy;Wfy%JQ$EV~c$H#+9&LWV@Z}CD);^fqvocMT%zra2e2H64jA*jFySz8P$&^Z`c z#26SE{`0YcD*{cHTdd$3Gm0Bb>46)6MW7-VERSR$*i?{#w>WGd?z01Vp%~;X4n`hk E081N64FCWD literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/novograd.cpython-39.pyc b/MyOptimizer/__pycache__/novograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9f49e4a9cd08791de1aafb9c381117c17a9d9ee GIT binary patch literal 2259 zcmYe~<>g{vU|={fWp-jNHv_|C5C<8vFfcGUFfcF_n=mjiq%cG=q%fv1<uFDurZA>3 z=P>0mM=>)(#8{$O+!<0>QdnCUQdm=&o0+3nQ<#DoG})^D`Q?}8yB8&<DEJqYWaeg8 zr55RN`4?rTXXYj5C}idq<R~OnD)=TACl+U9DtKq+mKG%{XgKHRm8BM?r{*Q6D)=NO zX<8{{l#~<{Tj}ekXO?7?Cg~;T=jtaT6zS$9Ch6zpm*uAyC8lr%Bo?F=St%q0m*gjB zBo>!sCM$qkkeQlSqTrial98WMtWcg=lA+*}SeaU+Tb@~*s^FNCSWuE#ma5>JpPQOj zQml}cU!)LHl$e>9nU}7RpQhlFnp&XXms(PuUzA;(z@?y|pr8wPRbo*^W|>}oQM!I& zQn9|FrGcrQp`nq58P`itc=~BF-r|INKN%#1j9H<aVhsichE#?q#uSDqrWEdWhBU?$ zwiNaj&M4*-o)pd$t`^29mUadfhA7rx22I{u9I&8udp>I>2<)_9$&{W32Jh_CWBLz$ z)rfg-|J<s_;89<tt6dQz0|P@cJJdE1#lgV9z{bG9zzmAy2nGg*5{4S)8kQQ?8nzV1 zUY1(+62=my8ir=31<W;!HS7zS7(wz3@hmkgDNK?KH4O2rH7qI2U=~{qYYjs@dktF+ zLp(<fQw>8rX9`OVLp)asLokCTtDh#zEvB5JTdYZ`C5gqim{SXiZ}F6;W~OJ9#HXYt zCsu-8_Y&kpH%;zaEGe1ECAV0MOAAtqia^f0#StH$nU`4-AAgHEKBwpwTRg}hLnv)@ zizOapB6ECs$}RqQP%4N|EK4m)OwY_qza<)<3y##%T(A+Di8+~7sVSNqw^)i(bJA|H z6(klV<`zfsp=v7vg;WtJh>AEF7#NB`dF2*IN@`kSX--M;O2%89@$tzyiN(e7@hcgA zS?Oow=cWdi<QEz0SLP(<>4VcrL1J>Men?SbUa_yIkG^juJQ2VXt6o865g!8s11~5F zgg}wPD8MMh$i~RU$i>J8f=pF{u!z=!g@Io(6G$l(^D{6ofE?=#GQ)%klvit+Y8X=( zN*RlUY8aatYME*nvKWdt)i7i+x-i7{)H2sFr!Yz~xG==l)w0yEWHF^Mlrj`q6>8Nm zWHHyUNHWy2f<&`eYT0VovRI3vO4w4Eni-oIOV}51)UYjNOkoygsAaEVD`8*2S;M}N zv4$awOPoQPp@t!gyM`r;MT9||p_U_srG_<y)rO&ly@oY~&4!_dqlPtw-G(8Bp@ubu z!-k<yt57b5vxXraq>s^sVFF`p3Bqj?7-LhATz8}J7?SH)Q&?FUBpFh;%o&&&Y(VC* zN`go>Fv%{-z|6qJP|I1%UJJ65Lz1DE3+5vfKXBJ@f&4y!vB;{1J5Q{JtA-(qM}(n< zJB1s>l3)P)gsX-vg&XWwE`(n#8EV;T7_xXLFcvj{<ZBqR_#_xWqG&cSWbxFni!i`U ztl<Wm$qrIq!y&?eKNP{?2yzu)4W}@J1Vb%P3W&!8<?+@qr|?KJl<+U$tl?eAD8f*~ zBh66DSHoVz!2}KuaM-ZhFw}5C!W<Mk3j}KTxEX3VKpDF5SmBNuP+86ES0v8Bz)+<H zOKl1%`KiSUdHE#@#ia!W`9&oP#RZ8)#i<IQd;%&=i(i6LAvojSVks_3Ex5&4c8f9h z76(W{d}?lDND(&!14B9_nSQVb=O|6VB2fkghFd(~To<2SlwVp<e2WF7<rZskNn%Or zEtZtTlEhn_nZ@xC3vclRmFAUX=BBz973CM*VlBxpO3t{&l9ykUdyA#Gu&AU6lyQoL zLAg{AM2LV00gx>$xrrqiw^$NWQsOoFi^M=e?8!O##id25w^(xWi;Hiuf?aTnv7kr= zq)Qn@$b$$45FrO56hYdU(v5BjBxNQR$0z3(6{RMZWaj4?Vu%>s;sgamd@-n^lLl!9 z6EYwc$Sg@v7L5Z12qOz42O}F3D5LW*axpS7a)J1a987$S3XB|J$iv9S#KXu3l@(x; zVdP-qVB}*IU}X8v#K`lXg-HU$XX0a&VANvdW8`BLV-jNIU=(4j;)O+`9>_VGtZqf1 z%Df0v)f9nBd~l>`GWuz<feO~V#N5>Q_*-1@@wxdar8yurPkek~X<`mU23$fFf!u$K z7h32gr{?6u#}|Qp2PQy1D+2oyR9y3b{0Xv>gOQI1T!3q`-C~cAPsvY?k1qmMbWx(< zXw=IGH47jlD9nn$NlF_OULbRjECAaAvfvhn4aAvtpxUAsR6?_`aWL{Qaxen`8#OwT literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc b/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3528945920d301895ec14ba380f4b1132b62478e GIT binary patch literal 3708 zcmYe~<>g{vU|={fWp?6jAqIxWAPzESVPIfzU|?V<Heq03NMVR#NMTH2%3+LROkqr6 z&SA=Bj$&qnh_OVmxHF`%q_DOyq_CzkH#0}ErZ5FFXtE`mF)}bH_?2a*WF{*3<(K8V z7bT`B_!pF9=4Mu<7U^;M7iFer<|XDRWabv+C?r+F)TR~X=PGz578j%zDWp~;<`(3n zT5%~TC@APEWR#Q?6kF-*r)QRAlqTsV=jZDCg?YMoI_kTmrWW|5CKlyo=B2wrG!^S7 z<>V*n=Oz}Hq!#H1REFdiC1>ad7o?^pX9T4t=cngomSpDV=|imI3P>zSEwWNb2rkJ_ z&PXgS$xK!Nxi~X5uSCH&wIm}yrC6anvm`^oC$TcMNVhz*I90(hC9$9+vn*A?H$OKu zucTNZEx$-1q$n{H<P3%UGzE~e6#P<4%JYk|ixbd%kyuobS*Dj?l&+tcRIG1kX<({n zXlP_%#`O{u6@HqGw>aU!l?)O>#%xedu>}JILn=cQV+unQQwn1ga|&NOLmFcWdkRMj zXA}#V&6&d0!V$%q!k@yE!rQ_a#n#Ti!VtwC%%CZ7i_5PJ7CDI;%nS@*H+q5s9pWR9 zWI8D35_8h?i!w_xa`m{tY8;Ewiy`p>ikpJOqQu-{1&z#-)S|?soK#IK1sD?)4q)lj zlGLJNg_3-Qd}tz2$S+bz$xJROR!B)r1BV&dWRQ6v4A-QPUX)*2Pz;yMDN@i#%gIkH z(E%HkpO=`UX{C?@jedoq#FA76jg-{1#L}D+D+R+;U1LqS-lWu$#9{@FkkW#j)M%(7 zU@8{Hu;l#Iw6x4*Q0x~gloqF^fTK1!KewQ?Br`7^ZXC#+MWuP*_)aWKElNyJEe3fQ z6l5UH3W<3s3YjIv3dMz`iABiH(lF4oG}Tcs(6hAAL^v<CpctFqN>Yn*6%tcYQd2<T zSCWyckdm60pPQMNSdw1^63WaiD9SHOMR8YNX>Mv!W^!VVLUBoAQf5wONhNBqSRm{z zPt8ovD2Y!=O-`)D<^YJi0#x3|NTDD#FEOX2QWMEU1B7uP@5e*JBQr1EN&&=&1P3^b zkV7yrw-_X=pplfHpOc!HhY}6t8L1^1sbC+L7N;tJJnrZl49b@ZWr;<ZiFqZUz%I$i zL?jMSjKI^X0=O^%X)6F_Z-oT^Jg_0o`FUljMd_(|$*CX%98(f=!7=5OTA81hk`Rv) zj)<h>mY7qV3i0LhKPx|8@=Xe`U&)l71_tl!(_{J%ebtD0Z~xq?$KX+4rK?>LBLf42 zTQUbUvw|p21_lNYpBYrVt^*aXj5SQzjKyLl3@J<t7;BgoGBz_dF{Uuquw*k9>y$90 zFfU-NVS$S>r?5yeurM$)WHT1Iq_Ac)6i4JSF*1}eq_8butO3g+sbz<&EdZ-+fvV*I zt7VsDfT%5Ds$r^Ou3@QRtzoNSPvPuktK}$RE@7!*Xl7c#TEkewv5<+8A%!8BL6gfb z9pnuVu9Ec3D@)ADOhL_iRtnX%2&vS9;>?`<JcY8voYGW~6gW(vW_+-(QiYiT%9#p? zJfB#ikeQc~TA^S7()1D(kc8D5g4AlV-D1iqy2T1Iu=o~pYC-WW9!UCyr0-jNpah3V zp10VcX;72x7FQU^nXW}e`9-(b((;RP6H9KffXb*_ti`1TsYOMgXuHJ`AD@|*SrQ+w z$$5*VI5j8j78}^%#YLc)E#hQgU?>7rutnS;77vKvWnf^q#Q`mpidQn;;*5_^&PgmT zj*nl-@XJO&BR@AaxFo;GP`@%KF;5>{VH6}Lr|O3kCFT|Tdiv=5R>Ero{k*cg{IYzI zH}ncBi-Z^$7!*LsNd^=vOdO1Sj36w<D8vZD9E?mL_?L~fN*Gbv>%shx%mmT}#USso zGB7ZJ=wclP28I%b5=L+;V5(tU$jHc0!;r=7!VoJ^%T&X(fCbbn0HusVF>u;o_A3JU z(@m4LND$;{kSVu#;^T``ONvVpOH$+GZ}ETwDjrgz-{LAxErEuKCUcQ6D2!OaDsQoZ zRTPPVjFA8lk|08gfq}sg<Wq3<$0)$a#aJbd;wOYvekdN}hh{rxknVeo3=AoZwM;dP zDU79zMM5=<%?!0nH4Iq{E({YGWB1iE*D$9rNkVw*YFTPnvOp>siYyA%K<q50TGkrY zEaoD=5|$KZP<Dl6yjr#zhAh?^7D<L$_7b)l_GZQu)@-JtJt?d;>?v$E3?=LfIBHn3 zm=-eDu&1zt`JCwt3mF%1f%uXPS=?Y54jYCVwxT^H><h5Ta@sIJRM)Vla7Z#RG1PL@ za@MdF-2!D?Nrqak8s-!(Nro7vTJBmNkiD!7k_;)_<_ydXU>kW#*cb5Du%_@VWLn5r z%Ui<<a^nQXBDosQ1$;HUph(UWtKnS0U&FJIae+VzcMXpOLk%w|nrql5Fc!s?uxAO@ zaMf_uaHa6}GBYyNa4is8$gn`5hB<`~RG`%I<q6erqKnnAE)cHa1DRgKoWd{3kiweH zJb|%jRt+myC%+^EIAoX@8ETkQ1SA<IFcxj9VO_wV!d=5H!2s6Bv4FpZ8|0=MRxqC> ziwUe(P?CY00a8R2sufDqfZCrzRht4*i$FQJSRqlNI5RIjCl%aU$jz)u1hxAVic3-p zplumYo#9xNUJ7pVAR09w;pCkB;!;q>ot&7H18Uiz)K88IF!_><#1e&~)KpM0mspZo z46bH!^HWlDQ0fv;bD$`-q_ijxq9P~1xL6Oa!mmh#fq|h)Ik+IPs5li~YZfbjnhtsS zB?`r*1qJy<C8;TT;EV<?N^h}%d~%CDwW1(Cu`K--7nBiSTzHEwH?blf!H1Mg><kPH zFF~1JlkpaFerXAMQ3Wc)G)0P}LAj9yRL0z5Nl7e8yv3PW9A6A}*)5)+(!7$)+*EKu zStJWm$()m#cZ;<I)X2WYRh3$lUmTy4nVouz70kZHRFZ#-EhV)qGdcAZb8%_zE#`v! z@>{IQ`305nw^(vZbK-BYB&MXq-(t>9thmKeTv!Ayc+287`61;3J2a$kv4Fy?NDgEt zV?mKBs8|3MI$&26fr`N*MUa>hNC7LT4lsmN1xB}6^74yvZ?S@#PDS8C1x$cT7I_8+ z1~X9E;0G!R*;yDl7}=P382K2v7@0t9Mh->}CN5?kMh+$cCMiZ1CMhN%Mvy8VCIKcP zCN4%ECNU;9CIKc9FprN>h>?qtk5Pb0f~krhse}T>l_pyesDo7mDrk#9B^fv#-P|-8 z{WLj1l}lb?E~u>JijU9DPbtj-v3cU-3riDopfc?7@hSPq@$ulIzX((+7J)o)ix*m{ zB&X)&#K$A~5nPUfJqjw5LE%*ls+KqyS%erE8UFJL^00!dHchr7M^NZ#F)%PhiGpKO x58Pabkf5Gh5h&^0V#!S`$pC9XvIQQTw>WGdF1G{K+r^+1#lp_P$iv9N3;<4?_-OzD literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc b/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6bfac29b2f804c5be7a3f18b97fe6d9401a3cdc GIT binary patch literal 3377 zcmYe~<>g{vU|=|FW}I|Oh=JiTh=Ytd7#J8F7#J9eA22X5q%cG=q%fv1<uFDurZA>3 z=P>3nMKLilq_Cv0<}l~7M6rO`Y&opCY*B1rHhT_xE=Lpxn9Y&HnadT$1!i;RaOd(w z@qpP}IlQ@iQG8%FcMgB9K$HNO&66XTD-<OJX7lC<=ZZv$fZ2RGqPb#GVvGz`3Mq^! z{5j&e5>XPll2MYmQc+U5(oxck3@HLpGT{t)Vk``)vX#de89*>aFiIwpA&<#`Aw?oZ zBt^7^F-k5)GDR#!yoE7JK7}cmK~t(KM!~<JBr`X&Dz!+#Eit(yzo=57TwlStw74Wc zS0Ow#Gd-h3!6h|0v69OpF*!RmMWG}=J+&kwwMZeUQbAwAIlrK?2&~A+z{o%$D8IN^ zAsocdP0ZtZ2?|+1O~zZS`5;SPGBPkQXfod7bWBN11KXv^c#F$1B{3tlxHvO052VmB zB{4Suq{b&dKRYoaH8BMw>X(?3n5)Tni^DI!EZ@BdEXd_ohQwnH0_g;44Dt;wD9SGg zanoeH#S-l95}?U)i`T8RI5h>P{1zve3(^2(`Z)SJ-Qoa?2D`h2xIO=v@wj8_k0*93 z8IqYno`7O@1_lNe1_lOaPzuRmU|>jRsAW!PsAZ{PSirE5!G$5#ErzL<wU(`hEsL>O zp-`ZNsfI0!xtTGAF`KE_p@ww<OA6CMMi8rpsf?k>r-o$#YYp2%Mn;B0ff{BcHcJX~ z3d<a(8s-{C5S<Q+EmprGkYlP?^^!6Zi=RW?#GuJ^i>)BBC^5JA7Ed`i!s1gv5vIv@ zOCT>XH#H?5ESp-AT2y?Cw<xu+G_xqRI6l27G36F>PHNsQ*5ZuBg4A0asd*{I<(VZJ zx7ZR33R3e@G}&*l=H{oQ<`juBFfc@M7H4M`#OGudm)v3nTM)&Omk+V(77NItTdZJ< zii8;$7{J6YEB%c8+*JL_oWwl+^8BLg;)2BFRQ-^m#Jpl(Pal2XN_gVa2d9vDNV2HZ zE2z9BkeHGZkL({wP@scc!obMF$igVV$i~RTsKCg>#KXkF$nmd8iGcw%wm><Z6&71O zEDQ`a3|R~d7#A|sGS@I<G1aovuoOAeFcmpuGSsq`fI^zFnW;#xhH(MQLXZe+4Py;U z4Qn%VEn5mxElVCp4dVi~g$%W9Da<v@Wei1WCG16UB^+6t3%C|C6sLJHq_8wIHZhiP z*Dz%9)Uc$m_A=G7moV2bq_8zJ6-}yP$l^_5C}o7OYS>fQZ5Y6^9L-Ec8)_J`_`tGY zRt<X!rwv05^8)@9u7!-X95oy%+>#)7GpF!mGZkN|VNT)AW-9tq!@fWug*}B~Arm7* ziC~sc4O<#x4MUc23SSEU9HtufC5(NHwVV@}@?O+1rwC*-O<*iCz^O+NP0yAZ<`f~Y zo`4d;ED?~|nG8r~PN`u|5eCZ_ln7>tV#t@&FsF!s<@-v+Fl57Om{UZ-vg=9&v&1pv z&1#ra#K7{GN+dwCDdH)7DH15Y5~^WNkp!#xQ6dRagT?QRd2ec%Q>3z)CNLFAln7=? zfz%=SgE8++4ReY#Sl$g3a}$_~jZzqD*i&R|7)qouRL`qnPLTzxPAQSWkZq`8PLTu4 z_LK-_$%5ol<iV~|Kyo=_UPKLZiXvFurV=@vs*GxwQ<T7}E|kcFRH3+Aq=q>~87%vz zL;+Pc??(-Dib^)q1m+?cl=xNzsRO$QC57ajs9{b~1*`YLp&li5<jtsIPEiA^&nZy? znVF)VBACetOC5O?HOwg*VEG=nJRg!g$p4yP`4uI~ApHoR<=NCQr)YuYF4V9uP)XrT z(FUh)RWx-XDS{K2i|(Z8Okgg4;Kk6ySfbX%SfXCT(ae~lo6RzTr3jR7^duQl^s|{j zyb_HX&Su71&Kk~O22BIMTl{(XC?!mh2m=GdE!H$pWl^Q=7~mSAkeHXE;2i4Wr~oZD zQxwwjixj{r6u?CmsOBuzWV*$aQv@z1Z?WeWl*FeN6yIV75u6}iQff(J@hw&e#Tftc z|NsC0Z!s6Ar`+PmOD!%*Ey^zo(PX;Ck(-~JnpaYqTcisrLG>6I7~F2LfU3S*tRS-d z7Hb}ezQtMup+ICodbP+p_lH$2_8=3=Q{syfOEUApCfs5vPERSg#gUkjn39@Pl6Z?V z5mxaQse*jUl?bnUZ?PBU78exd7o?v(0|IYA#<3>m6l5e8fgBRW36(ENyTy^0UzQK9 zKDhGAVC*OkaG(^Yr$q6BSuoE;ae}!Zx8C9gk@4m6pjsu03oc#`Rg;sLn*`MjGggx+ zJq--r*{8?!ANr~h^WMHn&@-<rF()$x5d+|gu?W<@xW$<RZ#ih*V$Lrpxy723U!Dq! zos7id#FCPtB4tnzgQ|>MT;N&@Qay5HBo@af7Nl0(VlBxpO3t{&l3bdSc#9{qI6ko~ zF*7GIDJS(7OG;*P2{;6c)Ig?afWnfsxF9F9<Q8*5enAlvNPs!m-31(Uw^+a}4OS2t zUIc1j7IA~Lv4IF)kWvmuXbcsB8W%;NI_nlYv>j8#15(BbB0%+7ksyf04<ZCWTDgi- zOHxwP5=(PRiuk|+VTn1Vsjfvu`9(#b8nX!0X3^xj#gbT*UR)##s>p@XGIL5&i{e48 zi1@_3l=!5)C~!JO3MUp&Wnbh3GLZw6gh9F?iXi?3sfRQh!8JXEPy{I#NG?iEEJ=-r z=MGTt6i))RRv5V$*%&z(Ss1w(IT(2u*_b$(I2gs4_!vbPxj>i&Brd?n!zjnd1trB8 zxfnrMgprF0L`yMpG0H&2I2buVy1*Eu10)VINeHYPq>_gbM1u5zFhpF4Nq~`siI0(w zk%Ory6x1SRDiQ*PC#cz?$rz%kaEm=YJ|#anKEB8WBo2zrB6AQ696J^u7AT5}j6p1L zEZTuspnxti1hF_kgf)l&$D0j^1&%x;5DOIfx459$DZV5P62qn-DNt&N;s>QaJ$UdH zgX6Ub9_5e}0BV8W;`Yof$S*2^L^dcbgZ%@E3y^OhEO4xF*yQG?l;)(`f$~T(DA}?w WbAZ}Tj6CR&iHT7H96wylT>Jn^t&r9L literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/radam.cpython-39.pyc b/MyOptimizer/__pycache__/radam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11a200d1a53e2fa7f59e54daaae3f30c6043ce52 GIT binary patch literal 4051 zcmYe~<>g{vU|={fWp<*MI0M6D5C<8vFfcGUFfcF_doVCCq%cG=q%fv1<uFDuf@!8G zrWD2$<{ai+mME57)+km+26u)OmK4?&h7{IRmS*NCHg|>;wiNajh7@)%pFM>sm_d`H zYFChBN@A{pe?duRZe~?#ksg<4Zb43JZfaghVo7Fxo<dG$T1je(LRwLNu9ZSYNl8Jm zmA-y@W=TeAl3sFtuD(xZWocrbPib;uu}@~HKG-;}fW(5-A}fUi|2&0~j8uiN#G=f^ zyyR4c{4_A%F(t8}B(p44!6!AbC@(WFT_Gs3Bvm0XFGayAwK6|1B|$+KZdYPaMP`{^ zeo?x9Vp6fbp{0R^o`JEEnJL#xQ26+1GTq{YyZ#nOQEFjnW>IQNGDs2`g8~r5FE(Ld zU`S<%VoYI(VoKp@XGmjAVNGFc;f!Jiv)NNPS~#LuQg~CiQn*_fqgdM+SQw(%f*CaV zZn1)W_x#VwkC%Lt0_;~ZrJp?m0&ncowcEvRp8NjZ{<&3;!K1!PSGyuc1_p*?HjtH2 z%+A2T0OB))qS1<hfuV$<hN*_RhNXrzg|U~VmMw)Tg}H^HgfWGsnbC!znX!f;o~eYn zgr$a|nP~xQ4Py=4LMBFr6oz01O;*2^EVr0)if*wcrIsWX-(pTJD89v0o|>7SQ4*h$ znw(gf46+|;GXoO?0|O`kf<ab`FfcHrGt@A|^3*b>Go&yqVw}iS$P&z;$*9Tf=jXSQ zL6hkglb*pX=FGg5B3=dt1~Bo<Tt6c}H#N8<zsOL(GAA)lzdXMvySN}RIaNQTC^4_t z*V9Mew-O!|`bCK;iMe_OmAANTax#lclJj#5?D!ZM7(Rm>Riy+DO1=2_%)HE!_;@{= zoc!d(oMJmYgyNT=9HGf}izOv9x#SjWQDR<t>MgdU(zLYHqFb!Rr3I-)MIcYz;((d5 zlKmD-acWN5Ew+NhqQu<dB2JL&LB<qugIGKa3=Fq8Qc}|rOLIz!S2Et>jE_&wNh~go zj|cm)NC2ctf`Nenq_P+kSqzK<j3SIxe5lS%1_da{A0P~3gW?o~i$Q@~!cf8pjx(kj z#)XWK7+cBcr^#F-4l*C?W}f)?;?$DjlEjkK`1qAfMN%NAv4Z(UvLIF9fCsxrfq{WR z5EOABix?QIgkWw#7~qHEARz_@1`Y-W24|3_1x%pwqL!(KF@>R&u}G+fv6-QksfHno z!G&Q0Q*1*ma}9F}qa=h^P|H%ok_A%9P-Ic424ZJ1Enr^AP|I4wn#EESRKl9V)Xdn# zSi=GpXRBezVyj`1WT<5?VXt9tW=vtuW-6MI!d%0i!eYZv!m)s}hHW8Z4SNbJ3Xjc( zAx{M)D#@^b3!$QhJ%tTPgw2Klq6Q+)#8Aso%UQ#m!Y;`W!&J*v%U#2{fV+l!A>#s` z6pk8h35FWh8rFr(j0`m#3%F~z7BVj2P2s5F0*QmN4#-Wc43Z2f9Oewn3^oik3|V|7 z{2)<Dh7?XQhGs@dhFTs7zlJA;S&|`}X#!(WNDWU4ha^J{PYRbL0~13nZ!KTmF@!1` zh8k`NPne;WKZOh85^07M4he=@ff|7nE{LoMLk+(OLkdR?zXXFcg9Jk@Zw)WVWE+MW z-W2X^)(MP7ziOCMcqAE01QrNF)Yb5&@JcX<Fw_WuXlam2E-;^0nxRGjEY1ZoZ6PDr zWT;KT3?Q*uzB~ztd7v<WxB(_p!w2$>4MPoc3ZEoH3UfB|1g4?@D376(5zJ#-zzY%K z2C3tb1i1rT1VP=)R8&{Pxj=9s16UouB*Q|+TEQB&1wu7^+zd4wHG-g+$m6R4nFAI9 z#{nqi2!mprk)cqnP@)D@4+;1cNi#4o{89%eZH1Kl)MACa{1S!Y(t?8gq7sGTg2bZY zRE6}S#FWg`ypm#Y;d+auxFogU7JF(%L40Ca`YkRfBfhw>2$bhvf=a+5VFm_<m!O>U zk`<J#KqbOUP!22tWpQw^peb^T2V7pnrx)dy78KuN0oiklB_**W@fK@ZPJUv^EzZp1 zc!(oz@dTCTm1O3ox)v4X7ZoWnFfiO=&PmO?#afbIl$>#kt17i9zc@Z8GduMbdr4(M zYJ6hxEtcHUocLQDi76?`U=~YaN=kf@7+7y+Udb(%+{BWMTP($eMI{hbDVb&Qw^)<& z3o7F^g&>6kdvZ>GacNQNEtZ`8;^JGZV8`BKEGW_e8N*!=4{~pOT7j_<*il8QAn&Sy z<iK@=A*8A>y2Sx0K~qz1vHHaq=O*6b1O;P!ab{KOEp||CViaF;ixVsrpPN{5i!~)R zFFzMjJ|YSw6$S<dP+3=;1TF$u7&#c(n0T1@7`Ygk7<m}k7&(}@n7EjD7&#dE7zLPw z7=@UG7<rf^7{!<b7{wTc7)6+P7<m{47-bkan7J6a7}Xfr{xkh&VdDG8#v;MU!6d^d z#mvPB(#OXnz*NNrP8oV2r)shkfuj;sJQsl~Rd8z1Wb|9f4yp?B5_41I<8N`r$LHp! zl;(igJn`{`rHMHZnIcg6UZemD9grt(@j~l><kXxTSUmvtC%6~}m%yO(0!mTf8i9j} zkClfRT=USRk--&^lbD$Y_AWMm;%-ua;<^N_Nm0UB!c+rlJ}{#+ADH}#Kox{0OA#mz zA=Q#UC`7<vBMD-GeGd)?a1ekbqCoWtC@dHl1(>RY5$=Rn<0#b-DAR%JL=eVVFBZu& zFfhPsMz9mWF{6mZ2_T7hkP|=#GcZ<(qc{LpWeBP{L4E~ShI;rbLq4R+kTLe(2&)Vc zl^9P7CrTy8Tf>{e38_Ft7;1RHRSgfQijiQb<*Q)^Rg2&%gt4dsT-bv}z?B9!Ql-I{ z!UL}4zSJ<Ma7!|j@Gam^;i%z}V5s2(Rc<1n%8eIfY7JisC%BU1166onaZZpp)=G}Q zhI0Wwv<l(_S3v?bYzqWx_`y|>0JsWLgV=~s4f2Cs3swoQngr3RCVoiORHO__k>DBy zRB9kfS8!ERqz;nR01@D71ysuvX@a;~AOh6*)D$Yx2C+e<b&)QJr3WJPL4*N_0M#T# zh9H&^h%g2bpuS3xDTrkTBFsUAC5QlZABwC&EE`b%5d>EeMYbRTI}l+HB0&8`aP3g! z2;zb*fY%B|P9QN*t|)Q`u|WOGB3BU04Mcc?2ypfX6W}rc<Z)1OUmQY4nGY`OnfSn^ zI&x{x!VD_oB^be_JU>#F##_`Ed7-)wTv~x02`WP(K&1qzdVrSIATdpjTkP@iDf!9q z@kQVu1ch>xD7ezr%LjFTAtb01Tm&j%ia_NSxR@^j=|`#n;URL1!v<2v*ntv8F{p}Q OVc}rpU=&~$@B;ud8*4uR literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc b/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70fc2eb4e1c4a2ff4658ffaa59558f13219a5a73 GIT binary patch literal 4619 zcmYe~<>g{vU|={fWp?6waR!FRAPzESVPIfzU|?V<)?r{^NMVR#NMTH2%3+LROkqr6 z&SA=Bj$&qHaA!ziNnveaNMTK7Zf1^RNnr|R&}2*WVPs%X2=WaMD9SHT$jwj5Ov_A7 zQ7Fk*NJ`B}EK60$$;?hw2uaN=&M!*K$uC#P%q_^_;_@%bOwY_q%*m-#NG>f=P*W&K zEG|h^NGr<ERS2jI$uCOIfEmh_QBqP+Y^ATCo>`JnnxvPUpQ~R`Sprf5r;~E>lk^SD zj4e&fQjC&}(~Q#$Q!P!66O9s$lT1xg4U&va(hQ9)jZ;#MEKT(xdh-iPGIRBda*GR! z@(c6|D!F_zlT-7GQ&SX5^HNfa6r6%xbe(e&ON&z#j1@F8i!(HJ6o_?7F32}U`aYh{ zu71I;TwK23a7j)q$;{6yR&dTQs4U7%&nQtaGB7ew2+A)mRtN|2a})EpUV<XVPm}Q$ zr$5MnnN_Jp$si$Q%mL*TTQD#%q%uS?rZ7Y?r7%V@rwFt&q%o#&rf{`zMzMg|+$lUQ z98s((f+>6{{4I=8Z0!sz3{mXC44Oi>ID>qHLB0=hOLS#pVDQW>$Vtsj%_}Jeg*`|@ zAu%UCzbLaLBUeEK5{GV}I4mxy%t=*9Ehx^+$<Nc|;!;phQ1J5)akWwi$;d2L$ShV! zR7lAzN=+^S#jb8*UWzUxe)H3iV;5{gd1grl$S8%xf`S}S(n(59%P&e*C@w520UMl{ zm!goUkd|7mkYAFKS_DmFu&^x2S4hsuF9wMxmLz8&r?j;EB86O#U}j#rLPljlYLRY1 zVo_plu^!aL`6a34`9&oel?urjiFxU%#R{2u$vLGdsa9Z%4fPaUAzp;LEi<n;GX><f z(!`=v-J<+_h!c(U6p9N$Lh*@Z=?a;7nI)NtIhj?Ee2|})S`0SNSWm$xNTE16F$d(J z{4|Bqf|SIPRE5Ok<kH;IoWzpU6ot$@h1`5l6qM#FB$cM6r4~Wl3d&viptPA(so<`s z;E|bElAi}soskJKJRv#1w5T{$!3LB>t@QQF%ggnWi}gzKi}LeI^7T?vO7-hXGP811 z^YoL8lZ}lG^^0>dQ&Nldb5fH_N{dqCA>8<!)MPWgf|NA7gm`E$g=C~EB&X(;q!y*7 zD3qlZ6=&w>DWqi<6_<d5B{i`K6l@9!?x}gHMTsRKuLY+TmZs(<L%kgiir}Es<kF&| z)VvY}ztqyA#2f{`)ROZ2qU>UbgF)Flv8W=mOfSDEUB4hDP2bSiz(UW!!qlM56xBJ7 zMd_uW04jz=5-90GVpTySvm~`BF)1fi(@FuxRLD<L086Kqq!tx}iY0Ki&#X#S$S+bz z$xJROR!B)r%goEnO9w|1$UG2+Yf?xr$}cT|1Tsh>r$|8~Ehj&*L`MN+RDND!j;57D zPHJKi$P9%dP>gD%q^2d7=9E||7^dnPX~GSLM=cf;;Bsk+$tC$k3J5a{5C$aX6l5e~ zv%feuKffdc6!;n`smY0znhMGJdBr7(c_qjO=~*I-N-ZeHW>86LQLaK_N(v<HmSm(V zq@?EM=Vs<5f*e(n4=TTl^2<`;egcI_UTJP>QD$;tjzV!sVp3*KW=SQIwT7v>hDd&Z zrzMS~{QMlGP*ku|$V^j6NC+t^O-)G9QAp0uEhsHXRRB2xt_fPx=%CmSQjuPin34%f zV41}VdHF@T&<vPVsgRhbkXl@lnF}s`K?x$Wq*$RWu_!Y!FB##F@>Ea>91jjXYykt2 zSAfd<7%3E_<|XEoRB9sm6v>w<smb}J1v#lHFcZ-{YXve1qCx@1=)_`$g47~-O@qi4 ziAlx!hUSKbdIqM328M9|=M=?f=EcKP4|<@YCrw!K11hE=g&nxi0fivk3~0dvidh(j z6;%+&L6bK)<3p+iP}L48sh(GhtaE=@)nfnr>=_VvWB+{j-s?B-*W9)(Vq{=ocn+<* z7~DdV!L1S~pPPY!0mNqpwL|8B+98ZJOxcXZVkHbIObZxmm=-cNGd3}%FxIeSGZyQV zFr+XqV60(*i?U@i7CV(Nq_8YttYL$TvSu?DN0cz6ur6S%VTFq_XEPS(l`y2REnuu+ zhKQCh)iBkt)iBqv)UemE)^OBtrm**N)N++Dm$1|@G&3z=tzoR;TFAu6kirnmpvmD^ z#LmFLP$lV^SC*KQnSxpXSSeK3BBY@8q(WI@PH8GwN(!b69^?o)IhY(IRX~#uLS7su z4=D@a5}F*hm~x74alnkc#R^t%i#fHR_!b8&|J>q1Hue@jS~9)Ghnk!;*=})#ft>AH zRFq$Ii!BY5u}W^SfC|oAti`1TsYOMg2)o4*9}lYb;^Q^BZ?P1o=A_+X1DBx1MWD7y z5h!AdKrvjz%fP@;#0esJKm;F%;Adc9xWxfWR4W;8amL3d=Oh*v$H%W^_+_J?k)N9y zT#{d8s9%|rn5Pe}HwqGyQ}siN67!0EJ$>|jE8&eLSc@*cBu%fNvPgu1fx!Tjq+~$p zjERGhj}Z+s@iDUeXJe`of;WcrV4g{40%?L`kUv=&7#Kh_xZzd8P{IgK8ca2e3mF+1 zY8bMZT^M3zYME-77O;St(4gc3W;3U-fY^m%;Kak~SHubOHz+$4fr7zJleI_~Bmy$~ z7EgS9acW6%Nn%NAeEcmQa5%<8YN1<P#i=FGfYM|x5(NbuD_G?%R<Md936L>TAVL~M zfPC!;@;)dU7#IZ@xfppEtHhAKim=NM#lIkjfpj^8v}-UiFr+ZnGSx7qFqSeF3Dqz* zGt@HGFk~^fFic>I<)~$@VNPL^1o0SSf7G(nuw;Q$G89=9s)5*9Otq{vtXa%OekCj^ z%%F^1!;r;V!y?I0%T~fx!`950!jjEYv>}D1hAoBFhM|Og0Y?o>7Slq;8nzTR8-^O@ z6n0656qan3q8TOZ3pn9&9AG(4Nrs{gDB@f;3?QA747Kbj+%?Q8Jdz9&47D5}(^(lL z8B%!78JHQsdilWSGEZPEQb^&iVJ>5sz*xjr!<@n|31P8hF)iS#VNT(bWZ-51+snkr zkSA2bx&UM@pCnYQmbHd`0ap!24eJ7)6oG||HS7`$pg6B#PT>TbR93^5!X?R2%UQz- zQCGtWRl&Z1H$||9vxajaGuTWaNrr`t3-}f?)N<ts)v%+BA;Mw;QxR7UTMCCHLoIiT zFx2h$YM4`m!8YEj;a<QGQN56Hfj|v&iijjb4Qmb8LS{y=s~2$9aC0+&(r%s>np%)5 zCb%lFPHu)8Zjg__G0Frpj~i}cp<1Ct4X9HnTD2`8wFp#JfI0%jnR)3sso=&^Ze|sv z5-u)DEr2#JK{X7PHXTS9)UYotN>$KEPRz*xHSCZY3swq_3NZPSjKmU!qSRDSS(sRo zS`2P6<mRWO=D=%7kZ~Xk>P!@+mXsFdK~&}B7Z>ZnRrwW(feLj^X!cM@$xj6}CQ1~F zOA89}i%JxV3lc$%Y*<^P7+j;>VgY&Q78jxm0M6dG_+dqHJfuPeXZ2egAkCnDLOP^e z`Ctz&9bbYA190gGDidFVN}iYh|Ns9FDK<exB)BNm6fBYj<$4y7F}GM!5=#<qab_0B zLtJ)?C#W>9Br`V^TwWK+gH$r-q~_gX1^0Pxae_J{@t`jGEv~B6qWoeAn<X(NCH@w3 zVoJ&_=7Rk4TkMG`DapB|Ik#9ry|DOOEV-pQ@wYfYqA8hW@tQo4l7}4{0k>E{!C0gK zvW>BzNCQ;xfUGW3Vqjp1;?K`ZjnB=@D=m%(*GNSwAa%@%W$Cw=lS<Qyz=ai<0GC*b z3=9nJpdu##6h7=Mj2w(?OgxNyj9iRNAT}chBNsCdBL|}ZBM+khh-Tto<Y5wE;$swH zlwcBK7GvfDi^?$xG4U~SFo`hAFbXh>F^VxsF;($Fqc+4%4-{#d97Uk9k0MaHUc?4+ zH#izVRlA#8h$f?-CI_f$%S+4!6~$cf@wxdar8yurPkek~X<`mkhCMz$B|kYn9$eiN zfePIskUwtmLMyW5)SR67cqFfai)pZbK}9|&#*4w_EhCE}10%zKK0zKfaK)<0T4WD$ tmM$nHxxtaB2kLg@7J(87SP7D;@KC(PVFPig9mqe$AfK^taxn5R0|339EjIuF literal 0 HcmV?d00001 diff --git a/MyOptimizer/__pycache__/sgdp.cpython-39.pyc b/MyOptimizer/__pycache__/sgdp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b747bafd446e4ab8cc65806fbf09cb1f38430e4 GIT binary patch literal 2915 zcmYe~<>g{vU|={fWp?6x9tMWTAPzESVPIfzU|?V<c41&(NMVR#NMTH2%3+LR1k+5p z%u&pYATj0~mR!~-)?BtIHZY$lhdqkjogsxKg|&qtg*BD6nK_Chg(;XplP%GZk%56L z*xe;S!M~s+GdHs;wMfA;w;(4qH#M&$u_QA;Pa!$KATu>ZA+0DsS0SUMq@dVJUq3yw zB%?G*FF8L~KRGAAEHN=t-!UaIH$XorCqGF)H?g=RwMaiPB{8=^zc@XmK(C;Ziz^_p zAhpO!At5*?zdSQ9UBM;4JWrt{BUK?hH8VY<M8Pk=C|AKVFS#f+u{c#BGf%-cALNMA zT-~I^;?xv`1B(+Bbm1;ZEUL&X)5|YP*H26;);BURFw--zFfugfa?VdlwSp_d<zOx@ z=lp`oBCrEAk~I~K42%qP3yM;UQ}ap`{2aqvgA|<eiwg9(d_6-Hax#-s^NLfsUV>u9 zPm}2uCp^$@aTKK%mSz^ErX+(Tkue99Qyjs-z>vxi#hAhn#gxjN0*W@q6s8pB6qXd0 zG>&$LG{zLx6t))5D3%oV6pj{#DAp9t6s{JAD7F;t6rL7_DE1WI6uuUQD2^1N6oC}M z7RD&fb_N!PD6U`zP2pQCplB>&WME)$OSkGVc+^+vYL`B1CJ5}bPv(NU0z`2$Fff4l z%%H@i$H2f)!cfCh!(78s!&<{u!(PKt!<oX^%Tdc!!dSvo!_dsMfVqaThHD`cBSQ*9 zFoPzOpC-pGrktW%9J$aay2S~K>(so=y!2Z<<>2^?Pf1Nqth~jMmkLS@`DM45Qwxf3 zv8JTvlqBBbC{Kwm0wpg^mRl?-naL%$Sc^*wQj3Z}Y5W#Pe0*kJW=VX!Cig9t;?$h9 zTWkf1MTxn^MXU@A3`J}pf*nMFg0_eg!~%Jz2;{FK9tH-6TO28=X^EvdCB-WlZ*j)Q zC+8#<7sto1WcX#KpOK%N8eEcJWT;=6lbEMpo?nz*T#%TYsvlC6m{;uU>7(yk36EA- z>Mi1DU|<jg1sBLl1|}{>j{j^-RlJ}e*Mqq>nF$oQPz+*&T<Z+dB*ehLP{X(Ygc(y9 z7c!<WEd*shW<PLrz5M_G|9?%UTP$UnspYp=iZiQHS27g|fc(K&Q6viT9+(gViSWiJ zXC&t3rRKzgREvX52U*3ySS18;46-zai$K8(b`dC)KrR9m9Z<Kt1SOp!kcpa%MG_!e zr9cGOYA^w|pF2J$u`;y?Y&%H47;3u!#CC*aGRXJPz+eD{8Ys1a?G6Ej21^ZNGh;1d z4NDDEGh;1N4PzFl=wj++T)<evypXY$wT1~zgsq0LhDn4WiwWfBUdCGX8ukUu3mIxy zr5S42q(SMD*{?{Jfq`Kq<1OZt%v?~?%F8dxy~UE4k`jN5xwtg<7IR`!u_gyNY>T8p z!O2)z1WG!$I6<KqpH`Zee2c9j9;}V65<)Yl<d>kwf>3;NesN}AYJ72KZe~tmQD#Xc z$mU{Dure@mFmf;nF$ysXFja{|f(1i$GARARLI<P^6hh$8IRq-)8Ee^7n9><j7}FU` z7-|@^7@HYW7$q1)7{nQBIcgZPm}(fZnA4bC7-IWsIcpdeu+)G;5R`j47qZlH)o?9f zTgXtuQo_D~qlUAFF^jdCaRDbN<t=2CX2@nJ(gKOWWU{zY7*m*g85eNZa4ck8z*EBp zQdP@c!(78%!&$>!!z{wk%$UWCtV)Ezf}xhVhP9TxhPj44jfshYks+9&Pz%BfX3%7T z#05A#70EL&Flh4KV$Mm;yTw|Z0jdaOKxs=3L}-EHfh8reB=Ht=ZeqnPmfXaWj9V<l zg+(Prk{|_&AOf5ql|U?IkYPOFIK0JJaEm3qC^6+0dqHYZNoi3Mq6~n<D^q#OE$-Bc zg2cR(cu>YFQUlq>g@}(^%mwjz=+Vm^Ur>~vm6{By|I|TQOav5Fj4X^&i~>wTOiGL@ zj7p3wOngjL&}5HL?uU}+_!t-%K&28mW^EWj)ix-tr!bT<775ibHZ#<MVv@myVFF{U zOf7Q_a|)v*LoG`Ua|)9rLoI6!a|*L0LoHhka|(+j11PLhSS1-;7-HYma@25SF{Ut- zG8Basx`EhPOrVsI#azQ7$xzEx!cxQ4%$UNK%~aG<!n%O1h9iq<A!7|!3X=^3h?Hcg z<*wmgz+S_;kZ}P=3VRKk1VarcsD!U!FWOSWS;Jbx4Nh%5c`P;DwLB^8wY()PH5^%- z&5SAR*(^oxY8V!9foh5xo*L#94oQX@<`hma$tB6KkS&I(mZz4th9!k9n|T6bkpjs6 z1?(y8HOwj8U{#Eg3?d9Qydn%K>@}<s4AKk|3=0_<8Nli|AgZ_-YIr~{Ds(H9r~#Fi zJbp!>Y=Tq)fXjYRxeP8y!MPtT1%RszP3~Jf;2I=8y(qu5pt#5gl#p~mNrAPvB(Ws5 zNE^&9$uCOIxW!eKT9jWLpOcxLdW$8uG$+0YobGf$dMrSRj}Mag*pqYei%W}AZ?WX$ z7Z=}R1)Gi4g@z!lU^}fq_An=vrrlyri7x<GR_KY3rMM)uz#Nq7%%G``gOQDi2UMzn zFcS|W7b6oR2O}3V52FAx2csAx3nL#B3o{ENAEOF07o!HV1XC3!xZDAmq{--~$qlMf z^AdAY<Ku5}#mDF7r<CS^*gWy^g{6r(5Sb!S0S&I`icm9W5y;a;pd@^Y7h0Dlr{?6u z$0NC18014x1p%(XL3LL#h~QvkW90hJ$Hu|T!OX$MApn-u<h;cmAD@z+93Nj~4GIqq zaJcK`<=tY+%Zn0)@bW?JQ9UpTYLym&>KTX=^+9%lYBD7I5fv1N4a9qPpz5p`RJgNn NaWL{Q3NUgo0RSw{v~vIe literal 0 HcmV?d00001 diff --git a/MyOptimizer/lookahead.py b/MyOptimizer/lookahead.py index 6b5b7f3..b8e8b00 100755 --- a/MyOptimizer/lookahead.py +++ b/MyOptimizer/lookahead.py @@ -35,7 +35,9 @@ class Lookahead(Optimizer): param_state['slow_buffer'] = torch.empty_like(fast_p.data) param_state['slow_buffer'].copy_(fast_p.data) slow = param_state['slow_buffer'] - slow.add_(group['lookahead_alpha'], fast_p.data - slow) + # slow.add_(group['lookahead_alpha'], fast_p.data - slow) + slow.add_(fast_p.data-slow, alpha=group['lookahead_alpha']) + fast_p.data.copy_(slow) def sync_lookahead(self): diff --git a/MyOptimizer/optim_factory.py b/MyOptimizer/optim_factory.py index ce310e3..992231a 100755 --- a/MyOptimizer/optim_factory.py +++ b/MyOptimizer/optim_factory.py @@ -75,7 +75,8 @@ def create_optimizer(args, model, filter_bias_and_bn=True): elif opt_lower == 'nadam': optimizer = Nadam(parameters, **opt_args) elif opt_lower == 'radam': - optimizer = RAdam(parameters, **opt_args) + # optimizer = RAdam(parameters, **opt_args) + optimizer = optim.RAdam(parameters, **opt_args) elif opt_lower == 'adamp': optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) elif opt_lower == 'sgdp': diff --git a/datasets/__pycache__/__init__.cpython-39.pyc b/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13006df55083dc070f46953e4505bfef1ac8b198 GIT binary patch literal 193 zcmYe~<>g{vU|{gyJ10?#fq~&Mh=Yuo7#J8F7#J9e1sE6@QW#Pga~N_NqZk<(Qka4n zG?`yAGB7Y`GT!2KNi0e9%qvMPN=r;m_0wd!#g~#;k{F)}6Dk53w34BSg@FM={4&zd z$j?pHugpoz(=X32$}TQQOitAgDN4*M_Vx792Wc-(Eh*NIkI&4@EQycTE2zB1VUwGm OQks)$2Quw5$ejTG`7km7 literal 0 HcmV?d00001 diff --git a/datasets/__pycache__/camel_data.cpython-39.pyc b/datasets/__pycache__/camel_data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ffc96a2ce0ccb33cabde901455fd1b7c9c44811 GIT binary patch literal 1852 zcmYe~<>g{vU|={fWp?5nHU@^rAPzESVPIfzU|?V<j$mM5NMT4}%wdRv(2P-xU_Mh6 za|%NWQw~cmYZNOZM2;<&or!^wA(tbH11!Rv!<Ne##mUIv&XB^A!rH=+!kWsF#nsFl z#ht<!%%I8k5@e^JCgUxZfW(pvO~zYXDTyVCIr)hxsYS^kIb_TXbCVha14AkU$Q@Bk zDI6(GDa<X5QOpn<qF7ScQrKG<qFBK;aNgo{PRvcsaY-ym^h*Yr05cK92Kmw%<Vyuc z28J4jEQSS)B}~l>3mF+1Y8Y!6;+bn0vY4}2Y8c{K!7R2KhIlqGi@k;+o&(I{Oku2H z$YPt#kis;VS)8GUA)c#*JBz1=A&V=8IfbQ{sfHn*7p#)6h9RE6hAD+Lo2h6~4MP@h zmOu(a3PUSX3TqmZB*Oy1g&^HRC^BqFGQu@XDeT!yMfXrtup_BJ)ysh-1JcWx%`}0r zNCrhWCz2Xe-CRgA!YK^F44T}2w^-8hb5g2U^pcCqiWnIf7&IAgac1VFq*lZy=jYsF zEh$RO%!}d%Q}H>8NvS!vn9CA#qBzSEbD(^dlGNgoC@v5Q5z!R6#StH$l9^l*AAgIj zv>+w1B=r_sPJVJ?PVp^nP;w|vEs0M~OOImB1KA$M2BwOOSQ!`?ZgGGV#HVBy-QrG5 zO)M!bN(FJEIKh^Kne55MWgx~arh=4P97U;#De=k0Wl>zkIhiS`@gVKDctB))aYkuc zT2AUMb{K;>CqMZXTS`%WL0;l57Lb}-tYC9BnQyTar{<&;@q(PdT3nJ?lDd+iNQ!}h z;g_X;Mt*K;a7liVp?+mfVxE3^eo=ODL1J>Men?SbUa_yIk3Pt#;?$C2{bW$m1iL}6 zpb``enR%Hd@$uZCB&h^SsLXtfY|LDYASl5o#LUIa!6?Nj#Rws}7^N7igy6|rFFqck zJsA`;Al)DgVly!?Fo2UeD9@LG5;|iuDEl$`X)@hn&PmNH5(CF&kpu$+#2aA06iI`; z#~vS_lbRPFuMTn)$Rq~FDnVpR!LrF<8IXkx3=HfH3=HtBRm1?vS~ZN447E%ppv=Mq zVmC9^GM9id7)vu_4byCf6vnws(hRjMC9DhBYFHLB)`GGedkRxFQ&AQuBegJ;aMZA7 zai*|JGBh)~Ff=pPvXyWxV6R~VXIIu5Hc5tB)*99t<}@Zy(c)LdtXEJHf)W@YM}qQw zFvu?g3=9nE3^fd~9JP!!j1!p(S%Q%~ugP?aNzdRGV+ABAKyeER4=$UW%;J*d{M-UN zBL)VB&mccnX(9(mdTL2#NosC<yq*mt8QbY0^gyx=H#o#^F=ytL6p4cZ9pthieo%O` zmgE;DXWU`|6@o>=V3DH4y!6ytti>6L1*x~#iW2iu@^dxW!LfCVIWwgqiUX3EGgH8Z z-Qr9u&B=)`NG!>?#R1I~#kW|&`4rtZ-0|^n=W2mMMF`{{MlMD^MiE9XW-dkvMlMDU z<|<LF9@J#=^V4JlCEUEk+|>B^TU_z+x%nxjIUqJqe0*VPVh%*6NCxCfIS`=>BEaE^ zAV4Xt2$W<%DYO_wa4<43GW_R~;a~=fYw{Fnfb^+=0)(v~F)t;txCj)3x7b1c%*jkD z0tMkM0dVN)m6l}Y6zhTV*)3Ku8ywz6pt!%q4yp@^Q%k_+B83OoH6VxF;;?~u!VZ*q Tig_3q7&sU~kcUZ#k%t)o)y2P| literal 0 HcmV?d00001 diff --git a/datasets/__pycache__/camel_dataloader.cpython-39.pyc b/datasets/__pycache__/camel_dataloader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11da2c2e9cc3f2bb97782008855d83b10df237c4 GIT binary patch literal 2882 zcmYe~<>g{vU|>*joR>IBkb&Vbh=Yt-7#J8F7#J9ee=smGq%fo~<}gGtf@!8GW-!eX z#gf91!j!|3%NoVX2vWnG!<Ne)#SUh(<Z$G2Msb4Kj5%Dn+)NCN47of}ykHU59KKxs zDE?f5C;>1ZW;$DrV6IS<5Lk>oM>tm`N(9X2$PvvIjS^*KaA!#2OyO!_Na0H5%@S*7 zjuKCmNa4<AD*6I;lVp@+s^kKxg$z;BsnQE%Qg{|JrtnHKM9HS`q{_}_Na35y93_`3 zo+_8ZpCZu96eXW3@4^r(5~Yx;5W^Iu7^Re|oT}8!$jDH5FNHCfK~wN0C`|k`8E>(L zq~;ap7iluy;s{GD%1lhkN!4V$#h#y+8lRD0qRDiNBPFpUu{gD)_!d`5QDR<kT7FS( zF-S>qMrm3aST##PVo8Q3<1Lnw!j#-(kjcoH85ZJt3=9mZ;Lwg@O5sjnN?~qcjABk< zNnveah+;`$OJQ$eh+<9QNa1W@h+<3O3TDvcxh3G1npjd=l<Jh2?vtOGl3Em!42nXS z1)#|0U|?Wy21RKfBLhPT!ve+{#)XWvEHx}yOeu^d%q1)}EX|B5Od<@;jIB&53@Hq) zOleG#3^fe#tROY0DnPO*YS>EH7jP_O$l^?4mSm`9En!b#sbQVXkit5bxtXzsA)c#- zxrQO0yN0QTA)Y6Nt%f0<Hyva)ADHB?VaVd1z*y8(!!&`hNTfudgr$Zdi?x|Cg}s-t zKd6?iL;%EJz_gH|nbCzI)~J@fhGBtV4f{gI7^Yf|TFx4VEdB*THJl3>85wFAviKGV z)o_5>Yzu@JGB7d}?h9u~VPIikVQ6M%WXKaTWSGEM%wfR52!@dij0|~YP+QrcwuV4$ zWrN#lLaMD37z<6X*xP_?Zwf;&gC>Vx5hDWwL;9?lAn?-uC5ZQnQNK!6zcME=Pd~`n zIliDMKPxr4M87CGIX*cjvA8%hEi*Z>Br`ux|0NRx1A`{xE!L9!lEfTMky}g!DYrO^ zQWI0+lZ(r4v1I1tC*NYu%qzLYoRgXdPNA83@db$`8Min~Qj1H#%py<%xy77Zo^p$| zAhD>V_!f75QD!<!A!~ACaz^Sc){>&c%)DEic@Q_G7T@AXgCxD;TWmRrNvS!-MVt%_ z47YgGQ%mBL64T>B(#5yf5(^4a^HOfHl@_EVmZWNO-C`+D%}Kk(3CcI|Mfv$9MN$k5 z3`LR*3=C16Nr~yjU`O3z$;i)5y~R?Hlvs3&vE&w08psbxIV%~8G#D5de%b4T<bzA{ ziwq$lRi0mzU0jfuoT?uJ%7(t4KKdXVKsi%CIWadiCmzJj0p}6Dg34PQ@$s2?nI-Y@ zf}pHn0Lm~-TudsADlB}ALac0zY>aG7Ad-WTjZuhEh>?$(hmnU-hY7*1l0?sYdN4DS zLGc1I5`@_p7#LV#d9a3o0hCS|Y8bN^Qy3)~q#0_NN|+Wf*Dx()tYxlYUcj=Dp_ZkD zbpcxqLl*l&##&Yoox+sOROA8H!z{^A!<xcu!;q&_!T}R$W-Q^X0p%KI35IOOqLLKm z8dhXk8-^N|8rC$XU<OSVNQx`s1qC`cDEV*|Nr2d(a46yju>?TT$Wf9XpI4e&P<e|f zuizGIRccXwagj7gS_ni4gLJYLDS~2+JGr<lJ}omRH9jRRiVqy^pyUUMN2ZjtTdZKQ zTdd$@0**fjp$Rfcj)8%}1r!G$a~W6#7`d1^7<m}E7zLQ57=@Us<S;@RC4nV_attV# zK^Vja1wFX90~uVyki}3WP{LTlkj2!@7|Z}>F@sp(?8oS*$pT4!ApaDJfZQVr@);x8 zXAoC{(pwS8skhkU<8xB;;^V_X&ILs)10x$F7gLo4h7+L*QPKgZTn1sV;|@X7L6K7p zLl$EVqa;Hu6UdQZ7BjfaWol+@W~^lbvzbBc1*|EIDNGBQK<T802^5<(3|VYN4K)l4 z*lQTFII=iFy1^wQGgKcQGxE%!`nX8bSHldk8ytYlelP$3|NlRefr)`ZlN+3gi)2Be zC<7uuRa+4#K5ns;Wu}%xLK+k;kdOwaIB;;YfP?!MYi3?bYDEz!GZyKC3;+i`m;i@A zcYJ($YDs2EYHoadEGP~@wlFZUF>*0-G4e2qFmf=lF^VzrFp4pOND-zg87wgaH^NVo zDI`Rb9aI$LCFZ8a$KT?LkI&6dDa`?~dE(;>OA~V-GDTV-=YXAV2x6In2(ZHt1jva+ zpb8h1e~LkcDhDH%F^4b*8(0)%O>SalUVMBJ4@fsC?^U@&3NHf1JgDRX7i7+$f~+z> z&(KUC!PHAGE_2gl1P8vR$}P5n#JrTmVsKWz#R|^CMLM7y&03OQl$-&Mk0?GcTQ9M+ zBtN|<F{KC;Q&GYY!Mr@Zw9>p}Pz95iQv^z}w*(+Er6rj;#d;vu++qc@Z?S+lQCuK0 z9wNdGR$7)>oSC0j1S<B5Kp7>9w>Uc|HL)m953H()9poT(P(_rJnFOwYAe9HZ3#2s# z4zVILkSeD9;#+L#sU^wfDMex+0oLMzoXnCUWl$h0fdtq=))%Li++xfu0;LFW0=mUo zP?VWhf|P~9F$IdGTO2l!ShfR|7sa5c<zV7q1eH3VGG2&L0LEq(N(6~>F!C{iWH?wj F7y-L2y3qgt literal 0 HcmV?d00001 diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..147b3fc3c4628d6eed5bebd6ff248d31aaeebb34 GIT binary patch literal 7729 zcmYe~<>g{vU|_iWC_d?p5d*_x5C<8vGcYhXFfcF_7cep~q%fo~<}gG-XvQd}6owS0 z9Ohh>C>BPD7;6-33PTEW4qGmJ6gyaqC5I!IGl~<;X3gQs<%;40vl(-^b9th8plse; zJ|+f6hFtzA0jP*zlwb-&3R{j)u5gqvR7@mF1ndIQC^0ZC9wqM1kiwqA(ZZ0zk;<MW z(aanr>CTYCnZnh=kiwNJ)yy0vohp;Uoy|0Xu_y<_wX#vNsVocR7BXbXr|?KJL@A{3 zrYg*4Na35y9Hj_l^Ur0DQi8Gt<}ybqL)n6JnWI!vB~z7Bl~aULgi}OPM0;7IR8u)q zRa1CU#Cn;c*i#sy)KVCt)Kl4A7-9{gG*Z=5#Iso^Fc$eSGE87B%t_UVVUE&_(u&ef z)k)QCW@KbYRZrnDXNXdtz*uON!Whh;De)2%7k-+Iw^#xaOEN%oN@7W(CgUwGm&B4p zpZvs>)FO~*NnuK^CgUwu&)mfH)MSt*WXuBP6h|>IFr+d>F{Us?F{Lm@F{g;MGo&%5 zu%xiIa7MADu%)oKFhsGYaHMdyFhsGXaHVj!FhsGZ@TBmzFhp^r@TKs#Fhp^t2&4$M zFhp^IJl?_(#of-p!Vtw1%%CZHOVGo`&D7V^2jsdOu<H_kF)}a&r4|&W7N_Qw6e}d= zDI_KpmlP!?mneV~D}X#yoLZu%;Fe#cP@I{Uo|CHJVXBavpI4HYnU`9msKBM5pa3B} z^9o8!6bcfH5_5~Kz``J@w9K5;_=3ce3@ZgtC@GZWE0koUDx~G-q@)(X49d()2a7_j zQ_#rIOI64(QphdMDakAV+3DhDssJ*oSW^#fbZ&l1s+EFzNl{{Eo;pamdP!<=i8@>! z<i~iBkyZ+xX$m1lrKvg!AhBYF#GD+seF~Ymxv43ci6yByl?s`8CHV?Lj=p*dp~a~R zB^jB;3Ylqe13~@)IUZzYacYS|X0bwAW=SzbS8i%<eo>{Kf`3UyYEgM+G03vi#4@NU zi6shYi8;lo3W<3skl+Ft2Es4{70NSna=`9W$Vsfq%&A0p6l7m~a$<5uYJ72KRjQSO zUukYqYLP;I8YGs$0i#fokyxUToS3JOl&Sz$oT8AK2lghIr=XFNnwD6aQ(|kZsmBHO zliN!X1_p*?P%#2cgbbhphl7EEff-b+<S{ZZlrS_iEMQ#7z{rr!(7_nblnEl4!6Zu! zTL)u2YYlS;V?0|8V+UhAdks?uV?0L<O9x{-X9`G!D}||+y@b1iF^eOGxtXb7p_ZeB zX8~^w$3n&$W=W7)Ea40(3@i*R49(1p40!^E48>&z42)nH$-u|}PX7!Yj9Dx$49y@D zSxWdy_&XT0*qa#_2&AwrWNc>agcdX%3?Ma~4DAf<jA=|MoGBbF93^}py$b{vGBh)S zWIGt+g-V1=L^>F=gqj&a8M{Oj!ggVZm8j+HU|b+pB3{F}Kmw$rhIJtmBSYbv63GQp zC6WuI7c!JcE|5v#TF98rv=HPPS%^%QTncvzPcLJMWR^UPwLl?-H-)c-VIgBWGgz%6 zSgleDKNt%1GM7j$P+rIYHAy9fA(%l^(C=jh0|P_IEyki`Mp&dk;{#Mu1cTy5hJk@0 zlc9zoRwRb0ma&$pgrSD9nK6s8hG{ZWAxkjBN(N1)A|?g~hFdJfC8@cZ%(s~I3~n)2 z++wUO;$~oA029Bg^)vEwQ}ruz67%%S^NX^J3lfu4^+Sph^NM{vee^+zt~j-%SU<V6 zxFkOpoRL7eC$&hgpz;=%O-g2RNpgN}ft@)61H)&KQ>*kqWd=$CqZc2anU`4-AFpSV zlb@WJQ*5V)&{)N$rI(SCW{Tobkh4LF5A0DP1_p+7kVm;{8EY7`7$$<f2X?+DV-XJn z1H($DB3_V>_&@~QZ$$zi&vV)2WEMl*=n9esnMR}=;TkdA#>&9JAPkBGHJIDvKyK?` zC}SvMDFG!^#%9JE#uNr|25G2!H5sc|!5)Rh3MVMTf(ykWkWnj{qIf|f#qpVWV9{GF z#i==IU~htbQUuDTelRb*WME`qcnL}znvAy?3s77L5`~8zC@+^VEMTl*SjZU6u#(YF zlj#;yUcoKa#G<0aN|5G6h#lZ?0|yz0O-^ENQc9v-BFwfb9ik!?s!fwA^Yj1z|6hXQ zM3eaze@1C)d~rr*T1k9PW^QK5E&k%f5>$b*#2jP+P1Reh#ia$QMYlL$N#_<PC@IEg z=B4G|;sP_l<=HL1;+)KsRG-A8)Et-0<dR$b2sv;ud5Z;9K@^FClR^$i86<Ey^Wu|p z5{rvdi=#M_64Rk-^9xe*Zn0z)=jYvG0VVTWETA+1$wjP1={fmHw^*`MD~oTjf}`pd zYe7+FUdb)yoYcG`P@1^KAD@_#0?s91rxxF0E-op$#a@(JoS##cdW)+Bl(EwCi*k!^ zu{-DI7UUPF-r@?$5AgH>w+n8u1*L)t_gjKRU^*U2y+}}EUP^whM}ARe6{zOTam&dp zxFrM;4NEO5$xKd!h_U&Aqwf}9W^OvN2@(*s9;K<li6x~)i6xo&d0~k;rKz_#Lh?gW z^NRC}Zt-Olmm%rV<SG&eB?3_R6={Q5S|CCP6r7x(5CZ$PNE6I2N=+^;D$XoRy~UUo z#gkzQii+aYlK8ZOm5jGI;~@baA73N}G6Iy+ic}dG8010u$O%+*Fmf@mF~T8)&BVga z$Ee37#4Ezc$Em=?$IQnh#>~RV#>nxHgO!Dmi;;_o=MM)P3nLFB+rKI)+{rW<RHlGz z0bx*O$Hu_G;0&_+3<Cp04Z{M48pef;G0e40wam3F=?t~3H4IsdDU8`nMIJRwE)20s zwQMyE3z%xy7BYfa%up8V0u~U9k)cp2PXVk_s+O&UHH&QlJ4n2itp=gOr-Y-1wV5%6 zDVwRNt%P#{R|+$zYD-~V$WqIi$5q0)fV+luA!99T4QmPK0-hSyEZ!6*Nrr`tk_@#V z5xyGcg^ab#C7fCOS!`JXHOw^(S%N8Sy-c;tHLNwvH7se&!3>)0ez#b`?V%!2N$@g- zfq~&As5~f=U|?X-WW2?lSDFinjMQRHp<5j3sU?tT1jnr-D0W;x5z3ZWP>`CJQe+9@ zaHW;zC4*8-V$LnjlKl7*aC#{+1<A3N<QFAp++r<C%uC6;#adjFn4Eo!1H>yxEy}&c zTAYzska~+Vu^6iB7E@l{E%yAp)cB12k|I!I(qspxvs<j0c`2zCw^)-BOF#yKGwdyv zyu{qpD9-%Sl6Z(Xb7pQjdZEP~AD^CDl39|P8y}wyidH31L^BC6vM_NmbAS>GBOfyd zBM%5N2{CdpiZF38iZE8mVoNe`lQ7CvP$~ka8c4YcN+ZpT!3;&xprB;}2S*X8QqW`s zdkEc??D6qAsd@47Z6Ftd%wk}ylEmf$u&QKGtbx1%!k{z@!eHA$9X?ovU&{<m2TWxQ zMK(1ISxlfN6(j|xfKzG~^8%I{#u}Cy<{GvXra3GN89}|ZU<OTQznB02|Np;|1zNa* z8t6qapy-bU#VRMbgpJQksnBEv7lGiqwnzh{0OZ9Ye~_Oc5eF_{i$E1@6i0b+W;`gy zz((ET1aV=}mQsNhAVp>j3=Ey1zyQ_b4B%+uVr1cE;i^)^5-g~ND@vB;U|?VXVNj42 ze`5f(Sr;&NFr+XpWa{s!WlCqLWv*dbz_gITg(23YmZgTNhDDMgg(-!pl?jx&YFHOA zm#}0(G9SYNHfT0vUce4wF)|cN)i5mJsA2A4Sjbe%R>BGHl`*HVv~ZMg)i7jnL)w+i zjJ51FY#@FOdkvd7gCs)=s|`aqLkbg0+jbsu+qQ(Kh9Qf$nX!f;i!Ym{=u-(lgx|^7 z!MH%6g8>v>Adv-v9gHb#pu&I=)Uss+cdI!-c7a;7H4O3GDNMl(nw(WtpiU>a(+2Bj zfg@5OJ+mwo(iH$zUf{MrsOtfWb#RBIv>+w11k_W3_8b(zc@f%ufGY<T>fpLfp}3@| zG`Xa-C>7C1DFWpRKa@-ZiVqMr21PO`Pu4KRN<mtzHB2cCk_;UT*-S-DHH={lj0}ZL z;7&DD5vYUg2Tm)R%;2`&Eym1SjAfdPkc5V4t8m(+r{<+r6cpKQ2Bk}I>rNj_%7NEa zNNrG<&LU6=2`(#9Tn@^fpbQ6UofTWb8hS>ch91P-DU6aJw;O>PdXfw^jG&^rhB1X% z65L;z!;;5T!w6}tYO;ZS1hxa5Es8)ntO(q`N&>kX<RMK+EA|#sN~R{`EuNy(+<Z{G z2-MyHhXkS(cmbpZ>{SCIy$Wi^fZOSsLbq5lObaS+vAAXCq!t;0GCEsgaY<!CY7wYM zaf`JyGq1#=$QEP;D1+bP1vhC*@*#cCTP%606(zSgQ}ar5Q;QNyQo*Giq})>n=>atb ziVVQUf_;#Bi@6}bK$8Pf9D>>rQS2F}@t`DFlnPSDQUY>F97t~(0|P@82S^H3>VQ41 z$pndUP>~J}fLmPg@VHzCN>g&6!jYAa5!3__VB}%s0+sNLGK>bI0!&<tT#N$DRWkTO z38jz)rBo0G7qXx@uVJcT$YLmCERv{UOb6vrh8iXthCBg~XbqDjsA0_r%AqA-bs&*u z#uP>gP%n<DYNBIGN->t42<rY9rIwTy<rOPHh7M9vEA$jXGE$3DLE{Oafe&z=h37g@ z{3)d67b&Ecr506!2SLE?SRLfS6VQMLD0_lBY>;dV8gR(WLC&C(pfa1e2pm<Q=mp1| zCNm^<Kt6{gR8V>Zhd(&XiwYPR7%qVV8&twFuy8T*F$pkM$zustPzeo7D=6hN$ZSxK z0*5Rp?Sey)QIY|ai)xu_KxG$`A2?rVGJ+Eo%vI^2qyf$^sIG!ETtGF=BaqWT9b*P2 zK1OVfKu`+@q!7hLpavkwRbUq-fE$5X3^hnD0#)v1j73V|W*-wMmf>y&)fu417eqa? zBttC=T#f}?pD=)W;*t!Y%D9H5h85f}W39U542mT1kP@strI46cS(2HXs!*9<s!*Po zR{|=v6O)rui;EQ!6_PVb^Rhu>JfMhGNGr<E1r6rF(*>vi2g!hk=R!auRp0@n+{Elu zh2qj8P-&i60@49Vg_*_Rv87b-pipvRajHTlXaEbvxV(Ju_+1IeQuweGI5i^5d62c4 z3bv3M0<Ig9bijovycMCz430}kf-3_hJ~l{N1I6ksrnCY`;R{y}Np#?d2KNv^9p2}l z=mqud8QA$iEd&-mMm|P9CR`B>3J*;tw<1tl)#L?rAwf|eAAgH0K0Y@;r8FlsKK>R@ ze0*VPVh&V>Jw84qKRG@g+#xB72PHUAf&jNpia@Rg=YXPekOas@w|JqwndH=*ocMSo zM|*&jfL#u+IzWk?gOP=s&xb>rgPB8?Lx)3vgS7}G531%s#TOid5;3T4R}3oE;N!rY zpiV*yXfzly{>zfW5zL^;RRroatz^pd0d=;wfX6{H8A0s-3=9mQKx1HFcwnCl>X^bT z0I}h@5LDjRFxD`{Gk{9Dct%h&GM=f1rG_D%88i+BY1Fb5DT4e2s=AAmK`bdyEV4pk z5sMU3V1UB`)W`b^ia}6(FhF{Bj8%%b`l9f`B9y$!$-uw>ayvM$f~vn7@Q7CyBWP3- z)acD-C^}ODYUnbRFlVtYWU6H<VP3#i!vyZqG1s!xvX(GpF)m=QVaVbDl{jFU3Dg1< zhqUMz7H~qlYs@ul<_xvWVD;d(9aKFxsD57v8p&i{z*ECg!@iKImIbbg7pe-<VPmT0 zC;_Qrt>IY6RLfezynqkXJp+{k3qdZ2xs)H|QgGSF><4KbfD32=NL|bZ>O$q`7J)+i z7ISi$Q4~jVS&4IgPJU4oud~08e^7joyOWVeaM&%@;)0yal3VN`9hrG0kT}Z+)vxTO zd6}RQhayl6-eOKnNdZUWE!N!BqV!alIUd1bMnUdQnw;Qxy~S3Xm|KvOS_G<=Z!u=v zVk|Do09DSM@UfgIPWUKJ6ep;i0v_Q(i+oV#OHKr35f(-UhGI~HV_;%o5@6(G6kwEN z<YMGtlw#xoV~`l$Q~@!^Pm`sn4wMQT;6xK7JywDQsz5|FhyW$KqFxXST$msTaH0g| zoZp~C3CeIBOkC{XyaaM?ZenI$e0&io8h$bASLr~8MqN@<3*3tm(@NqCit@8klS}lI z^HWmwQ&N*k!9y&Wd3wpkWmURps$t`&`k)a`eIrw|cq3D@(%g7(Wa(E~qALgY+B5Ug z<4Y2ga#D-+!Rq22L4*Ez@d1v(@rgM(dZ6K{Dm`?qpi-|mwM0KTBQ-f2!>S@L1_lPV zm!K+7lLcuY2h?MObif&Jae;ej;E}MGpmw~bZcz>>0kP(l<`z^!QY1U5w~&*W1nCb3 zfszC_xbv@9T9TPlTm-65Z%HGG=z-h?CP1yW)S@C#fP=@%qQt<;$})>H^Yio&16|<Z zhA3gMq+VX09;}{B%qfZhMP~rWO^j~0n3IdkZZQXV`V@hDRn!6!W-Ca{OGzxg#Z-_| zqz4jU2TgMor<Q<A+goh;Nuam`$8`}Xz1(8TFTTZ`T%K}^wIH#mr1%yKXf!;E6C4lm zAPXSFjz#(TCE!rI#hqVTQczj~ZhYNh2YC);I;e8G#h84H1vLH;#hjFwj?~IG0l6BH z{=h6yBD}?61F4SeK*Kb}pdl9yCeR242NPtJf`?Iv5zJ@gVB}yDs%OX)kOj4FIhgoB ZZCsEl79mC+W)?;kM$iZc8z@6b0RSRBgwFr~ literal 0 HcmV?d00001 diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4af0141496c11d6f97255add8746fa0067e55636 GIT binary patch literal 5690 zcmYe~<>g{vU|>k>vq+jC#lY|w#6iX^3=9ko3=9m#GZ+{cQW#Pga~Pr^G-DKF3PTE0 z4pT036f+}4j3tUSg&~DGhb@;qianPjiUTaplEazH6~zT+v*vK;@<j14GNiDju;=jR z@<s81*&I3ixdKrFU^SdMg1JIbLW~UV3@KbG+${_#+^MWt!p+Q4BJK<+JSn^_3@N-& z5z!RJU<OUTmmoL$X)@m8@GmII%+E{A(PX;CQ<RvOlAjx2T#%Dla*N9)u_VzaKQSe> zNR#mvtFNDDaEK=3Ev}NH#Ju9P{G#0Amy8Sy44RC$1l&>+OG=AUof6aG>IFSq+)RBv zeL#lfzzkstNG!=n23d`aS)rWb2nGg*R0dEmMKPs_wlkzLrm&{4wQxo;r?97Rv@k@m zq==<(rf{_|MzN-Fr|`5eM6sptrtq~eM6svvrwFt#L~*1DrU<n#L~*7Fr--yLL~*q< zurNe%2Qz4j-{J+i#WSxYwJ0qyIrSwdmR@o&FfbIcGcYjtC9^@D0ir;{%nS?++@P2Y zV_;w?VQ6Mpz_^fsk)ejMh9RD*hN*@jp1Fp(h9RD%hNXrfp0$Rxh9RD<hOLGnp1p=) z0S8EJ3FiW?g$!BTDU6Z~wd^TOHSDt)Qkdt0bn&FH1T$!|R{e-BF3rtNEUJv<;!=PD z$D;IND>x4%n3PzOoDpA~S(U1wk(pPbqmU1aH7kYal+@znqD-(rte%2PYFc7xPD!ys zNxp)ai5`kEd8N7W<@rU~sYS&knPY)!OiE%&VsUCod|qO1s)9ywNfEK`QCCM81P=9I zjQUmX`jt6}dHO-l&hZ6B`B|ySCHh6l$??fKiN(d4X_?81p!BH^wX;~?IWadir!qg! z&<u&Gmt0)tr^$PZwYan(wdfW{e0*kJW=VYfErF7v#LT>SMEKnjgoq#o#V!7l)Z!9k zH3DD(WF<UM`@w#`#g<*3Sd?CTO8}Iz<3UpKx%nxjIjKdU!r>Nka(T)v)`G;MlHyyO zVA-Pl{1Q!$TP($?IcY_Lpfo20B7{MN2#62`5n>Ds47XS^@^e#HGT!2hhXg@<{7Qyj z7WyDjNI-!iu(%*GIaNOdloNeDeGq{Iwk#8t#Ptd)iv$=L82CV`7-S6t2Nx3?BLZ?T z@-bG4p=1j^SkNVdGB+qMfG~&+!p<O*Kt)^%LokCTqhFCc0|UcK#v*Bu^FeH|K#>B- zBRmB~sRfBeso?NY0cm6c8O2g1g=!(PvShFlkTM1a22i;GvI}f!2Ll5`4Py#pHdC=c z4Py#ZDML|82}2D-7Gnx?3QI3jEmH|o33CloGh+&?2tzYtEi;T?!z=;fvy`yZFgG(Y zGL*2?uz+dS8m1cN8kTgXW~LZWc7|kK#uT<}rs5K)9W78h*g<yGFvN4DFa<Mca#WpC zP=MzwU4`WQ(!3IdywcpH)FOraGzD<lO)W;`XkCS*(#)I`g|htQ#H7-k#G*=Mxq{Rp zP<g5VPMP{;i8=b9^rrwWhEWVjEGWpS1Sx|SzzP~EscD&csVNGn6$Lq&$(bcNl?s`8 z3YGb#MGBy_07;b!`9%t#%#;pNnVF{m${U(+=ig#Z%PjE=(PS!;0!14u*cH!b%>;p$ z_P1C-ZgJD(ECQ8@5GBl+c_p`)b5iqeu@)3%=9LtIippEOV2$9|E>11E#a)(|1LyF7 z41w`9S;6V$7He@yVtVQ=c926sIv5jgF(yG05I8A;QxR)%YDsB<7AX0EoX5Z;z{tkP z^q-AMfKh;vgOP(#j7@+EBvK`enykS3FcLASkO!wHP+o(?Z4E;fQw>8Fa|&Y$Q!l7c zX7&qN$$X0osm$O;E(+L-GfLCaa#A%}ia@Ee2;?@TJgtZ&D?m#jkVio!Mrx54$Uoq; z&R8Xf>TX0RLX?u@g_X>HZbe3*${FfLZIGi7<qkGC^1<Q@Vy!Qd3uRDUh^mkr_Z68V zxlIS;HbhB?&25kZ8WDT}NRE?5bsRjQLA-(DI#7Va6F>;4ea5(eA%!W0xdmL1EMQ7u zS;*+Z(9F1yv6it)u7tUSWdUmnE2xl}%`lf~A!7|=Gq|7#XUG#_WB|bu_7t{e#wNx{ zhCC(%25_m!?stm;!&*>E2?n_uR1Bvx)G);I)H2pEE?`*5Fp;T{C71!6?m^9~<ivu^ zlEj?Ms#Hy;TTFTew-__Qau5QXC%9~KGK))+^K%RALKzqsK7-23DkIcz0C^vr;h;5> zo=r}Ea$-)gogPB-EynmNaQ#%AT2ic6W#E}tmY9>7q5v|fIJHE<Ei)%o!4Fi)=NBo! zB%ML^GFVg*RHJ3(XXX`wYOh<Y;PUDgdwOa~Vo6ESEzZo`g8ZVAoXn(KybvbTDYv*m zF$v>vg(c>crn(jt<rir(feVHxu4GUh53Xt<zJ?bEMdAz$44{}R2IXG{Mj1vXMm9#S z|4fWb|3Ot7h(^y)=plg;H=yJIiW_fG+$4Y-dJ7mqB|;YSLdGl>P?^xmB*{<<Dj`@G zuq|Y$WzJ-%WvO9uVThHg1vUDZ%NUBZY8bLuYgi>2YFKJmZ5Rr*YFHMq*Dx+*WCSHF zrZk33h8mV7jD4V{H*?iV1<$<VlEl2^OmH+SBr4?Qr=;d6lon^^r7M(Vq$(um7Zs%z z7v$%qfJ6!sixP8FOHzx9;bk_cO_)}cpQ`|BQ0QfVRTL|vWF~{!N{L04dJ2*Gr3%T3 zc?yX+#rX=Ec?G2<3W<4@3ZP~SiYX=e3i)NJMdd}AC8=2KE(1jls7e7_1Zno5wc-3U zIc~9M<`oyDCYRje0>yq&I*3!`2Flo=_Qx$2khMjiD7(dyomyFZi><UEC9xz`lMNc- z`6U^tMe!h2QC#546T)W8%mXQpVg<<<-eOKHN=Hi&w|Fy=!!i~WjUX!-SS1*_7+JtI zga9KKW0f>0QNtVmdMLUyS&Mi;Zsi3Lpw0leZR)4V=%>j8YD<FJx$*J0xZ<H{Ha`9q zPkek~X<`mU2HeOhQUd8V1QBK+!U9BCf(THWE%F1^t-R1SU2<wpPJBF)S7brzKzX|e z)Z+j-rWoW44kkV^4mK_Y4rUH!4i*khuq;OZ3FKl>5l~!!+<y{IWlm*IWl3d8Wldq7 z!wT+Av8AxL@J6wvh@`ToGD|X~aHMd~VFJs5x{bV19I2csJaZVqeJYkHE+`Mwvto(j zhVnrDE0!o8aKDN-m_bwY7B94$=L=~9gPc<Y3d)zD@-{@171C(n%1KO0&518aEXgP` z26@*66e6t2iOCtM$sj!-FMu#70|Ns{CqF1IW58LGP~R&B)F)#D_rO5CEKoBI)W_mX z0rkwdKqYuQcM7O^$CJWd!w}D#!coHz&j&88IsGyl7#SE|f^2;WYP}bMikFus7#J8{ zf@)B3QV7xHF9JCS649V`ND-*uMff9%D>FASJr&$TyTzFYX}PBsNAZ><7NvuFO5mnp z6n}1NBB(Qx9-ot%mtK+)#aEVC6rYrc;DZ_(Nr~yj@!)1a6fc6GT3iC^KcuD<f%>DG zoZy@T4q7y~6h(tV7o-T>%m8P*B5(pj6r-TTl?w`A8BmBbav}8z5WNF7MmDA@F_g%K z_YV9tnTm8l5y)1QT2fk+hthBd6{;W%Y8Dnx0LLJx@0`U1?gul0n$~3uMM)*hH4IrS zprSJc#9~Wf>1C;9E@3a>05#O1&3G0VzlH_ejAt$31U1@gSmDift{R4T?i$7vwlapI zCL}Z2YZ$V?X4Wvo^OW#5Gi34AFs5+yf(l_yKXANiGWiuLgF*|`J<((a2jndlP-zSa zL{(5Tg~_i-4WtG{XflF*t;rAW=oLwT#3exlDES~o1~>}TK@u7u0u&IpKm#Y4dEh$v z78kgpg|Io{O%_NR0~I0Af)Q+45h#v|^g!n5g9rl<ffk=2J@ugY1a&JJxcS&P7{!=C zEo3n!4n_&4Dq+-E1h<Vf8H+&0B}#P&>f?c8j)*3*CQ}h82N#)w!U)>Tj0ZUlp1F&# zH06pwW#b%>OF>zjfw4*s)n(vh4-OYlwVhgo;W|*K5bQcoLkhhC-^(1#u#yRsHsOwd zwbDU#J0ibgbsDI|o{!`-8C0hcX!a6xSWyyENF@>uDNx<92+29JsLp|>8c0BCvVc=G zxZnjBVIiREi?7HZWD%&nToee(uFS<HMMa>5d5a|}KR*Xjpn;;I2$X~%9U4#+6@dby z2;_vKAdm^*q618TQy#dg0<~weK@J2Jn;c9sLQ=AP%zVs6AaPBxA`g&07m%fFC6xu4 zdFe$Udu|C9RF>oyC1=FvWTt17<Ynfi-(o7ri4p({>Xnvc<`nCJ+K1pyCD=hxB4D|) z%;L=aJg^+7y;=kc@F;F1sUlFyjS_$~9zjMxLI&J%jS@^QEiTE=MM_#lpgenv9aJIZ zWF~<#bd<0U%u%4`0<=N{RmosaAywGmpa6x!Ee;z<C)*B`cfie3P_LASk%Liykp~8u GgyI3%g;R?F literal 0 HcmV?d00001 diff --git a/datasets/camel_dataloader.py b/datasets/camel_dataloader.py new file mode 100644 index 0000000..302cabd --- /dev/null +++ b/datasets/camel_dataloader.py @@ -0,0 +1,126 @@ +import pandas as pd + +import numpy as np +import torch +from torch import Tensor +from torch.autograd import Variable +from torch.nn.functional import one_hot +import torch.utils.data as data_utils +from torchvision import datasets, transforms +import pandas as pd +from sklearn.utils import shuffle +from pathlib import Path +from tqdm import tqdm + + +class FeatureBagLoader(data_utils.Dataset): + def __init__(self, data_root,train=True, cache=True): + + bags_path = pd.read_csv(data_root) + + self.train_path = bags_path.iloc[0:int(len(bags_path)*0.8), :] + self.test_path = bags_path.iloc[int(len(bags_path)*0.8):, :] + # self.train_path = shuffle(train_path).reset_index(drop=True) + # self.test_path = shuffle(test_path).reset_index(drop=True) + + home = Path.cwd().parts[1] + self.origin_path = Path(f'/{home}/ylan/RCC_project/rcc_classification/') + # self.target_number = target_number + # self.mean_bag_length = mean_bag_length + # self.var_bag_length = var_bag_length + # self.num_bag = num_bag + self.cache = cache + self.train = train + self.n_classes = 2 + + self.features = [] + self.labels = [] + if self.cache: + if train: + with tqdm(total=len(self.train_path)) as pbar: + for t in tqdm(self.train_path.iloc()): + ft, lbl = self.get_bag_feats(t) + # ft = ft.view(-1, 512) + + self.labels.append(lbl) + self.features.append(ft) + pbar.update() + else: + with tqdm(total=len(self.test_path)) as pbar: + for t in tqdm(self.test_path.iloc()): + ft, lbl = self.get_bag_feats(t) + # lbl = Variable(Tensor(lbl)) + # ft = Variable(Tensor(ft)).view(-1, 512) + self.labels.append(lbl) + self.features.append(ft) + pbar.update() + # print(self.get_bag_feats(self.train_path)) + # self.r = np.random.RandomState(seed) + + # self.num_in_train = 60000 + # self.num_in_test = 10000 + + # if self.train: + # self.train_bags_list, self.train_labels_list = self._create_bags() + # else: + # self.test_bags_list, self.test_labels_list = self._create_bags() + + def get_bag_feats(self, csv_file_df): + # if args.dataset == 'TCGA-lung-default': + # feats_csv_path = 'datasets/tcga-dataset/tcga_lung_data_feats/' + csv_file_df.iloc[0].split('/')[1] + '.csv' + # else: + + feats_csv_path = self.origin_path / csv_file_df.iloc[0] + df = pd.read_csv(feats_csv_path) + # feats = shuffle(df).reset_index(drop=True) + # feats = feats.to_numpy() + feats = df.to_numpy() + label = np.zeros(self.n_classes) + if self.n_classes==2: + label[1] = csv_file_df.iloc[1] + else: + if int(csv_file_df.iloc[1])<=(len(label)-1): + label[int(csv_file_df.iloc[1])] = 1 + + return feats, label + + def __len__(self): + if self.train: + return len(self.train_path) + else: + return len(self.test_path) + + def __getitem__(self, index): + + if self.cache: + label = self.labels[index] + feats = self.features[index] + label = Variable(Tensor(label)) + feats = Variable(Tensor(feats)).view(-1, 512) + return feats, label + else: + if self.train: + feats, label = self.get_bag_feats(self.train_path.iloc[index]) + label = Variable(Tensor(label)) + feats = Variable(Tensor(feats)).view(-1, 512) + else: + feats, label = self.get_bag_feats(self.test_path.iloc[index]) + label = Variable(Tensor(label)) + feats = Variable(Tensor(feats)).view(-1, 512) + + return feats, label + +if __name__ == '__main__': + import os + cwd = os.getcwd() + home = cwd.split('/')[1] + data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv' + dataset = FeatureBagLoader(data_root, cache=False) + for i in dataset: + # print(i[1]) + # print(i) + + features, label = i + print(label) + # print(features.shape) + # print(label[0].long()) \ No newline at end of file diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py new file mode 100644 index 0000000..ddc2ed3 --- /dev/null +++ b/datasets/custom_dataloader.py @@ -0,0 +1,332 @@ +import h5py +# import helpers +import numpy as np +from pathlib import Path +import torch +# from torch._C import long +from torch.utils import data +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm +# from histoTransforms import RandomHueSaturationValue +import torchvision.transforms as transforms +import torch.nn.functional as F +import csv +from PIL import Image +import cv2 +import pandas as pd +import json + +class HDF5MILDataloader(data.Dataset): + """Represents an abstract HDF5 dataset. For single H5 container! + + Input params: + file_path: Path to the folder containing the dataset (one or multiple HDF5 files). + mode: 'train' or 'test' + load_data: If True, loads all the data immediately into RAM. Use this if + the dataset is fits into memory. Otherwise, leave this at false and + the data will load lazily. + data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). + + """ + def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20): + super().__init__() + + self.data_info = [] + self.data_cache = {} + self.slideLabelDict = {} + self.data_cache_size = data_cache_size + self.mode = mode + self.file_path = file_path + # self.csv_path = csv_path + self.label_path = label_path + self.n_classes = n_classes + self.bag_size = 120 + # self.label_file = label_path + recursive = True + + # read labels and slide_path from csv + + # df = pd.read_csv(self.csv_path) + # labels = df.LABEL + # slides = df.FILENAME + with open(self.label_path, 'r') as f: + self.slideLabelDict = json.load(f)[mode] + + self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict} + + + # if Path(slides[0]).suffix: + # slides = list(map(lambda x: Path(x).stem, slides)) + + # print(labels) + # print(slides) + # self.slideLabelDict = dict(zip(slides, labels)) + # print(self.slideLabelDict) + + #check if files in slideLabelDict, only take files that are available. + + files_in_path = list(Path(self.file_path).rglob('*.hdf5')) + files_in_path = [x.stem for x in files_in_path] + # print(len(files_in_path)) + # print(files_in_path) + # print(list(self.slideLabelDict.keys())) + # for x in list(self.slideLabelDict.keys()): + # if x in files_in_path: + # path = Path(self.file_path) / (x + '.hdf5') + # print(path) + + self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path] + + print(len(self.files)) + # self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys()))) + + for h5dataset_fp in tqdm(self.files): + # print(h5dataset_fp) + self._add_data_infos(str(h5dataset_fp.resolve()), load_data) + + # print(self.data_info) + self.resize_transforms = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize(256), + ]) + + self.img_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(p=1), + transforms.RandomVerticalFlip(p=1), + # histoTransforms.AutoRandomRotation(), + transforms.Lambda(lambda a: np.array(a)), + ]) + self.hsv_transforms = transforms.Compose([ + RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)), + transforms.ToTensor() + ]) + + # self._add_data_infos(load_data) + + + def __getitem__(self, index): + # get data + batch, label, name = self.get_data(index) + out_batch = [] + + if self.mode == 'train': + # print(img) + # print(img.shape) + for img in batch: + img = self.img_transforms(img) + img = self.hsv_transforms(img) + out_batch.append(img) + + else: + for img in batch: + img = transforms.functional.to_tensor(img) + out_batch.append(img) + if len(out_batch) == 0: + # print(name) + out_batch = torch.randn(100,3,256,256) + else: out_batch = torch.stack(out_batch) + out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch + + label = torch.as_tensor(label) + label = torch.nn.functional.one_hot(label, num_classes=self.n_classes) + return out_batch, label, name + + def __len__(self): + return len(self.data_info) + + def _add_data_infos(self, file_path, load_data): + wsi_name = Path(file_path).stem + if wsi_name in self.slideLabelDict: + label = self.slideLabelDict[wsi_name] + wsi_batch = [] + # with h5py.File(file_path, 'r') as h5_file: + # numKeys = len(h5_file.keys()) + # sample = list(h5_file.keys())[0] + # shape = (numKeys,) + h5_file[sample][:].shape + # for tile in h5_file.keys(): + # img = h5_file[tile][:] + + # print(img) + # if type == 'images': + # t = 'data' + # else: + # t = 'label' + idx = -1 + # if load_data: + # for tile in h5_file.keys(): + # img = h5_file[tile][:] + # img = img.astype(np.uint8) + # img = self.resize_transforms(img) + # wsi_batch.append(img) + # idx = self._add_to_cache(wsi_batch, file_path) + # wsi_batch.append(img) + # self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx}) + self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx}) + + def _load_data(self, file_path): + """Load data to the cache given the file + path and update the cache index in the + data_info structure. + """ + with h5py.File(file_path, 'r') as h5_file: + wsi_batch = [] + for tile in h5_file.keys(): + img = h5_file[tile][:] + img = img.astype(np.uint8) + img = self.resize_transforms(img) + wsi_batch.append(img) + idx = self._add_to_cache(wsi_batch, file_path) + file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) + self.data_info[file_idx + idx]['cache_idx'] = idx + + # for type in ['images', 'labels']: + # for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()): + # img = h5_file[data_path][:] + # idx = self._add_to_cache(img, data_path) + # file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path) + # self.data_info[file_idx + idx]['cache_idx'] = idx + # for gname, group in h5_file.items(): + # for dname, ds in group.items(): + # # add data to the data cache and retrieve + # # the cache index + # idx = self._add_to_cache(ds.value, file_path) + + # # find the beginning index of the hdf5 file we are looking for + # file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path) + + # # the data info should have the same index since we loaded it in the same way + # self.data_info[file_idx + idx]['cache_idx'] = idx + + # remove an element from data cache if size was exceeded + if len(self.data_cache) > self.data_cache_size: + # remove one item from the cache at random + removal_keys = list(self.data_cache) + removal_keys.remove(file_path) + self.data_cache.pop(removal_keys[0]) + # remove invalid cache_idx + # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + + def _add_to_cache(self, data, data_path): + """Adds data to the cache and returns its index. There is one cache + list for every file_path, containing all datasets in that file. + """ + if data_path not in self.data_cache: + self.data_cache[data_path] = [data] + else: + self.data_cache[data_path].append(data) + return len(self.data_cache[data_path]) - 1 + + # def get_data_infos(self, type): + # """Get data infos belonging to a certain type of data. + # """ + # data_info_type = [di for di in self.data_info if di['type'] == type] + # return data_info_type + + def get_name(self, i): + # name = self.get_data_infos(type)[i]['name'] + name = self.data_info[i]['name'] + return name + + def get_data(self, i): + """Call this function anytime you want to access a chunk of data from the + dataset. This will make sure that the data is loaded in case it is + not part of the data cache. + i = index + """ + # fp = self.get_data_infos(type)[i]['data_path'] + fp = self.data_info[i]['data_path'] + if fp not in self.data_cache: + self._load_data(fp) + + # get new cache_idx assigned by _load_data_info + # cache_idx = self.get_data_infos(type)[i]['cache_idx'] + cache_idx = self.data_info[i]['cache_idx'] + label = self.data_info[i]['label'] + name = self.data_info[i]['name'] + # print(self.data_cache[fp][cache_idx]) + return self.data_cache[fp][cache_idx], label, name + + +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + + + +if __name__ == '__main__': + from pathlib import Path + import os + + home = Path.cwd().parts[1] + train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' + data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json' + output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/' + # os.makedirs(output_path, exist_ok=True) + + + dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6) + data = DataLoader(dataset, batch_size=1) + + # print(len(dataset)) + x = 0 + c = 0 + for item in data: + if c >=10: + break + bag, label, name = item + print(bag) + # # print(bag.shape) + # if bag.shape[1] == 1: + # print(name) + # print(bag.shape) + # print(bag.shape) + # print(name) + # out_dir = Path(output_path) / name + # os.makedirs(out_dir, exist_ok=True) + + # # print(item[2]) + # # print(len(item)) + # # print(item[1]) + # # print(data.shape) + # # data = data.squeeze() + # bag = item[0] + # bag = bag.squeeze() + # for i in range(bag.shape[0]): + # img = bag[i, :, :, :] + # img = img.squeeze() + # img = img*255 + # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + # img = Image.fromarray(img) + # img = img.convert('RGB') + # img.save(f'{out_dir}/{i}.png') + c += 1 + # else: break + # print(data.shape) + # print(label) \ No newline at end of file diff --git a/datasets/data_interface.py b/datasets/data_interface.py index 3952e5b..12a0f8c 100644 --- a/datasets/data_interface.py +++ b/datasets/data_interface.py @@ -1,9 +1,13 @@ import inspect # 查看python 类的参数和模块、函数代码 import importlib # In order to dynamically import the library +from typing import Optional import pytorch_lightning as pl from torch.utils.data import random_split, DataLoader from torchvision.datasets import MNIST from torchvision import transforms +from .camel_dataloader import FeatureBagLoader +from .custom_dataloader import HDF5MILDataloader +from pathlib import Path class DataInterface(pl.LightningDataModule): @@ -24,6 +28,8 @@ class DataInterface(pl.LightningDataModule): self.dataset_name = dataset_name self.kwargs = kwargs self.load_data_module() + home = Path.cwd().parts[1] + self.data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv' @@ -46,14 +52,23 @@ class DataInterface(pl.LightningDataModule): """ # Assign train/val datasets for use in dataloaders if stage == 'fit' or stage is None: - self.train_dataset = self.instancialize(state='train') - self.val_dataset = self.instancialize(state='val') + dataset = FeatureBagLoader(data_root = self.data_root, + train=True) + a = int(len(dataset)* 0.8) + b = int(len(dataset) - a) + print(a) + print(b) + self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) + # self.train_dataset = self.instancialize(state='train') + # self.val_dataset = self.instancialize(state='val') # Assign test dataset for use in dataloader(s) if stage == 'test' or stage is None: # self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) - self.test_dataset = self.instancialize(state='test') + self.test_dataset = FeatureBagLoader(data_root = self.data_root, + train=False) + # self.test_dataset = self.instancialize(state='test') def train_dataloader(self): @@ -87,4 +102,62 @@ class DataInterface(pl.LightningDataModule): if arg in inkeys: args1[arg] = self.kwargs[arg] args1.update(other_args) - return self.data_module(**args1) \ No newline at end of file + return self.data_module(**args1) + +class MILDataModule(pl.LightningDataModule): + + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, **kwargs): + super().__init__() + self.data_root = data_root + self.label_path = label_path + self.batch_size = batch_size + self.num_workers = num_workers + self.image_size = 384 + self.n_classes = n_classes + self.target_number = 9 + self.mean_bag_length = 10 + self.var_bag_length = 2 + self.num_bags_train = 200 + self.num_bags_test = 50 + self.seed = 1 + + self.cache = True + + + def setup(self, stage: Optional[str] = None) -> None: + # if self.n_classes == 2: + # if stage in (None, 'fit'): + # dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes) + # a = int(len(dataset)* 0.8) + # b = int(len(dataset) - a) + # self.train_data, self.valid_data = random_split(dataset, [a, b]) + + # if stage in (None, 'test'): + # self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes) + # else: + home = Path.cwd().parts[1] + # self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json' + # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv' + # test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv' + + + if stage in (None, 'fit'): + dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) + # print(len(dataset)) + a = int(len(dataset)* 0.8) + b = int(len(dataset) - a) + self.train_data, self.valid_data = random_split(dataset, [a, b]) + + if stage in (None, 'test'): + self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes) + + return super().setup(stage=stage) + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + + def val_dataloader(self) -> DataLoader: + return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers) \ No newline at end of file diff --git a/models/TransMIL.py b/models/TransMIL.py index 3cb4e52..ce40a26 100755 --- a/models/TransMIL.py +++ b/models/TransMIL.py @@ -47,7 +47,8 @@ class TransMIL(nn.Module): def __init__(self, n_classes): super(TransMIL, self).__init__() self.pos_layer = PPEG(dim=512) - self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) + self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()) + # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) self.n_classes = n_classes self.layer1 = TransLayer(dim=512) @@ -56,11 +57,10 @@ class TransMIL(nn.Module): self._fc2 = nn.Linear(512, self.n_classes) - def forward(self, **kwargs): + def forward(self, **kwargs): #, **kwargs h = kwargs['data'].float() #[B, n, 1024] - - h = self._fc1(h) #[B, n, 512] + # h = self._fc1(h) #[B, n, 512] #---->pad H = h.shape[1] @@ -86,15 +86,19 @@ class TransMIL(nn.Module): h = self.norm(h)[:,0] #---->predict - logits = self._fc2(h) #[B, n_classes] + logits = self._fc2(torch.sigmoid(h)) #[B, n_classes] Y_hat = torch.argmax(logits, dim=1) Y_prob = F.softmax(logits, dim = 1) results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} return results_dict if __name__ == "__main__": - data = torch.randn((1, 6000, 1024)).cuda() + data = torch.randn((1, 6000, 512)).cuda() model = TransMIL(n_classes=2).cuda() print(model.eval()) results_dict = model(data = data) print(results_dict) + logits = results_dict['logits'] + Y_prob = results_dict['Y_prob'] + Y_hat = results_dict['Y_hat'] + # print(F.sigmoid(logits)) diff --git a/models/__init__.py b/models/__init__.py index 497cee1..73aad9d 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1 @@ -from .model_interface import ModelInterface \ No newline at end of file +from .model_interface import ModelInterface diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1ddff6d6f3cbadfd7f1c0c4686a7806f5896e8 GIT binary patch literal 3333 zcmYe~<>g{vU|?9no{-eQ%fRp$#6iX^3=9ko3=9m#=NK3mQW#Pga~Pr^G-EDP6cZza z&78}`#K6dq%M!&36=92FPhm)5%Hhc6jN)WuaA!ziPGM<bNMT83%HnEfj^cJ_NMTK3 zYhg%XgNpFDGo-MmaI`R_a6m<PQ~6Rjvzdxcr7|qwPvKg~xIkbbLzG}DZwhw`PcMWg zlq#Gmv_NDbLo*{IL#j}!U<z+9W0YvBaH=R+L^Op@k|9bgRWyZPk|9bwRWwCFk|9bW zg)x{xQ}89or+%7@w*>qui%W{~a~(@cQu9hO^YfBHGRPPd4j_K94g&*2Dnk@w3PTiA zDsvV~3S$~mJ3|^{3R4Pm3uhE-3QG!W3qurJI|B<t6nij(CfhBpkfOxAVxPpy)S{OR zObiUk%pjAX7{umeU|;~z#TE<<3?&TB3=0?+GB7gKFxD_NGuAM~GnFvcFx4=nFiJ6` zFxD`oFiC-f6BK~G>@^JWEGZ1Z44UkIFF6<(7+x}g2sQ==hLH4GGeKady(arD=9J9b zD2|lO-1v;t#FSgCAhI}$H?K4|J|{6RB{#7syEux!ATzHlKC>jXC=nD2#kV+$Qj0TF zN)vN#v8NQ}7vz_gXtLg7EiNrcExN@KAD@|*SrQ+AizP3=D7Oe??=6<Zl9D`4=36Yq zsX1x4xIhx|IUt`Gf!wx|@fK%%d~!}=adCY7N`_ws`WgATsrr>UiFx|v`9;~q1&PV2 z`rzR5_4Lut%}+_qDTed)3Mz|u85kHqDYjS)6n>0cj9iRNj7*GdP|U>0^s`DBYyddC z^<ZvH2DuKT1%yFtPyjiD61D^b11O{!7BJK>EMTl*T*$bPQJkTcu?Cbx82vPviuf5A z7>f897#K8}iUb%K7;Z6E6p4Vuz=SABggq_4s64SKMI2-zIIS>>FjWa5IRvUelLHiP zd5O8H@$t8~;^TAkQ%Z9{Y@Yb|!qUVXs0>qH-YrhB5Bxwu3-Wal$n0CZ(BMl>&B=+6 zM{<S`$Uz{t6oFC_D6oq`?%-e&;RTD}PSI*eDLRD>DJ?TY(=rP<Ewh4>*e#ZT09W@S zkjWS!3bGSqcX1Rr6_+s9Ff=nQU;?G&66P$H8payN6mT}FVN7A^WvgL`XN8NirZA<j z!Nu9&;>;;bDNJy2c33)R2PJVv5CO7KlkpZ?dQpC9LGepgP$+<+R+FQM2gH^G5#ZpH z2E`MbbADc#QOYfrf};GaTdW|`5JDM2BNv=bi@>1?k6v)91DR9|@*@KyAEN}L1Y;F1 zC@c^u3=}A!)CIzzBnol{IEBpur?4!B80K1*TGm>SKN(UOTA4saha^J{V+w@F1QJVQ zl4M|FsAaF^sNq-u4y}bu3z%wHYS?O+YM2+Y)N<A^EMTtTT*z3%DGo}pEFd<R#R_48 zQtkq_6y}ADF#WZRCG1%oHS7yGQ<xSq)-a|qr7)$i^fG}|1~X`~`W1nKvPc0GR3JyM zWCRC-Cf6<2;*7+C)LWb-paQHQzc}?4OIc=Wd66neDJb0(se?k5JuN2@RC3;8Ey*uR z&bY;#oLF*;vox=`urxKbDpivUoNbCgS@ssA$1TS2TZ~S(7~^j-I^W_<&MA&B$<I#B zi{eU4O)P=3Ig<18;z4|*6b??@piH0w%4DGY$iT?MsKzM4sKh7)iVq%ee84gqBRG|= zWGez?@FGc&i={vWI5B{XDUt<o!I2Cmz;O<?1QgXENk|T2<YOxWiD2X)P}&1Ub}`67 zusQ*pe;8AkS~#MZA$499BLA@7;s6&3zMei9*$0%SK_(V=V97q9jMEFM3Rp^57qFGE zE?`e#OkrHeRKmJ|V<E#rrW%HL&JxxITqWEKcv3+2GA?9Z2vW~m10wlK_~9l9Ah88Y zSQiMTFoIPILrkq<$l?Q=Cj!oFEPmi%EYbjl6)1&hib67$IY<B;mY}e?#aWPF3@zxv z8I3DAwXhUamnP=iV)Mz&OHC}g#S)b26B@-5pO$QRi!&gxC^0v+B(<nW7nFopixTrv z@`_ABY|cDL6_HvDF1c>8<$x*>Lnv)jBm=S<6#utCX2CPA9Y`1)iQwo2$0W#s`k>?j z5(i~m9!4o9F-A5<9x$vDghdA;8-wx!C{{rjq#YD5-XJ3+7(o?93Zo>$0)~YQwM;ck zS&S)6k_@%XB}^sE3s_2+7qHeagQ|~aMobYV22kd&VX0viVW?r2V5ns)Vas9%SL=}M z&rrkG$^<D2(wHR~(m|P&xt9rQ4~ry2Eqe_^7RLh48uk>{6xM~z3%C|C)N+7L;izE( zXKILgh6UU;ObZ!nnQ9mo@YH|`57vduU>+}AoDal;*p|WqDs;f|{3UD)1VAzi8EZL9 z*s=s`I6>{2UZz^E5}_<%IFGxAvxd8dD}`+ib1hE|52$AFyTy`{Sds`SMZl>A6p%0f z|NsAAllc}~PJVi3N%1YV$oPVy{G?l~k?|RcC7L3)SkrRy6HAITLGjL<nOAa)DX-ub zOLA&v&MlVW!lDvz>L}6!MJ+gOu%%WMB<7{uVo5GdNi4DliGvzVMW8k+xQ;Gz1R2F% zoSB}RpP6!tEwLy)H?iUtquVX^;{3Fd+{6k^UT|@Ai!Hl6u_(Rx7Gp+{0Rsa=6jQuM z6jOY76jx$ON_<XgUV2GJkrBu+E_fMLT;u|>(G5hnf~0tgQj1G-N{ZuCGLuWNl~p#N zl%oVnOH4d0T#Q<bJd7I5T#O=&Ld;yue9T<TV$1@Je2jdIJd9NW@Wco$!-~v67J*YU zm;fc>B5=9}<vviLKuahd-Xf3)sA-X#n3)$JugMIKh|B_U1_nq51{Lg@j76YKqRCj~ z1Tq-x7mzoKKpxVR1c&u44saQ+mzP%r%E(c|Fs@!&X<jm@RhXD_iwnUn3IK%zYhGz? zL1mFAC>#azAdTsGaGe0s0cjet`GVUz;5-d3P{D476wkIG_xOSckjrkdf?Ix|iW!t7 zSW?RpbBaPivZzi*3UF}vfP&^0hYh4nWd}-s#h@aKgOP_(h>?SlhgnEeNKnX-0~+95 Ij2vJH0J_=bi~s-t literal 0 HcmV?d00001 diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45428bd7fc4c40dc9fb1116df4a96aea2fb568aa GIT binary patch literal 193 zcmYe~<>g{vU|@KnZk(jUz`*br#6iYP3=9ko3=9m#0t^fcDGVu$ISjdsQH+cXDNMl( zn#?a585kHe8E^6V=BK3Qc;=O)7NsR7r}}9!-Qv#$3B_l^#EL*htz;--VPJp|zYO#< z@^e%5D{~U_^vm;$vWp86lT-CWiW2jReLa2j!TO8!<Kr{)GE3s)^$IF)aoFVMr<CTT L+JUV7400y`&cZL_ literal 0 HcmV?d00001 diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e81c0ccdd33e6fe68a5ffe1b602990134944c80 GIT binary patch literal 10240 zcmYe~<>g{vU|{f<PE7h_#lY|w#6iX^3=9ko3=9m#-xwGeQW#Pga~Pr^G-DJKn9m%= z45nG4Sim%E6f2l!i(&`U98sJp3@J=GT)Es)+>9W-%sD){yit5$HcJkFu0WIkn9Z6a zm@5<|1ZK142<M7KiGbORIigXbV6((>#iPWbViLLHObm<+xsp**P!Z`UX|O7pDA^SD z6pkFZT=^(@upDQOLat(zBACsUqm-*0rOe3S&XB^L!qdW#!jmeSrP9nCrRvU*!kfa^ z!jQt3DxRg*%p9ej!Whh;$^Q}*B7T~Tw^#xaOENSWZ?P3r7UblYXfoasNG?iEEJ=;e zFDS{(&8$i-(qz2F4He7DFD}+(yd|7mlwVvNpPE-vlwVL8Uy@&xobi&8fq_Aj@fK%U zW=VW;ZemUj$kg2Yl++xM7Pg|);=I(7WRShcm>J3`PGMkRNM(p(Oks#(YG+7eOkqr6 zYT<}tPGL@AX<>+BNnuT4Yhj3D1%*QkLlj#IX9`yfLlis2$59**e@1bp@TUm0Fhp^s z2&M?NFhp^u2&ag&FhudRGq5m3@dh(!ir(V$1-Z&IuOzi7EipMY8I&TR&R`H>U|`^7 zU|`?|rH~?K28I%b62=;aW~K#93mF(0Y8VzULunQeUBX(!)XZ4J5YJY_T*DC0Uc;Eh zQNs|=QNxnLD9KR65YL&yl)~J?P{LKikj2r==)wS2%?;wEuw=6oT`1wn;sx_qQ`mZ$ zY8c{qOL!LWmGETogGJdvq9r_80$?6TFJlQ$mLQnT36&8_;RN%!dYMXivV>81A_!hO z+d{@#*1QmqeJSA3M71r2B?WF*3QG#qrWBSGBzsa=QsA~g*y*ewQ%iWVM8R(0=mmv^ zSPg5II4EAgJP9z5H-!hxldNG$;ge*@lB!{dmrmiYVThMW5vXB^m#tw*5tL*|5z1zo zz*NMPB3#RsB2vqqB3jFlB38>;BDX+(Aw!lzint_0Emw*}4cBaj6v?^FwcI6&HS8(; zDIC2_wTv~43zQZzxG==>)bgZ=)$qhCPhc$EQlhdzwT3ZEt%e~>b|F(OZw+sjdKPz< zMv7F5bc#$b6C*<nZ<b~j_W~_&jH)cqu3@a<T?mQ`oe7M2IZ&6$f?ZZuqMN0cB9|iH z%UH`-qF=+ez+fRmGh+%vib5-sBttEKiAt7s4Syd)tw4<chHID#ok~FF3zVoVFia6& z$XH^OqS(v`av>-bK<+D1Szx@7VIgY`L%dGj2dLdjV7q^n=w_Kf?G`N2uMvdXt&C>3 zP>m24yZ3?Z7DBRH1!}h{By3RZE`i#umd!MQxhOA1tX3GLzeFX=G(|l{BSkYstCty~ zqeLakEQKY7B}KcJ2`sOXqJvZ3C<VK%oFE^kFa$Gb>iXSc&dE<t29+qVe96SXz`)AD zzz__|mtqVI4CxFt46%H*ppa)sVVuJ>k*Sa+m|-PT5lELN(=FEI{L;LVTkOgCMMbH} zB`X<LGTma*Gq}Z=StP)~zyKzGnd@id=cejc<|O9nm**E{7Z)TZr|O3kCFT|Tdiv-? zN-qc%p9w23^a?6(aoOZ#7MCRF=N8y$GB7ZF2061z8?{`~i;vID%PfhH*R#pVPfpA! zw$nrCdI|E-EspZk%=C<s)D%soTiki2x$((4iN(dK#kbfK%TkLH(^GG;=4K`r<!dq) zNir}n++xj51o4=Pq(EK(2}QA#CFaC~irRvr)U?csTdXCi#U=5#SW*&862W2dl7WeV z;TA_yVsds;eqL%6OG;*5eh~))1H(&DK)eK%er`qV3=9mv81<`+(E|gdT0g(Eq@c7! z-!U;cBQ-DHNIy9vH95N=KQpgHub?C&9prluPOlbO=l-y&#a@%?7E5kwV%{z0;*ykG z9FVf#&>}>W@fKG>QEEw1VrE`y%FF-%|NqxyDguRIkpjqnP?e?zFaIK>UNW#SFcj&4 z<k^Zda|?1(UxI3lm!P8hB`8gRDyZz#qP*0c_~Oi}R8U|PmlS2Dq`ou))k)^JSc^*w zQj3Z}W%?}+SghO<DNZa)jnAkoNG&QzEK1BxElDjZzQvW3pO_L4j`<>x;kQ^omEJ8* zcy)M-GY=k3969;v@db$`8Mj!A67$kii{wGB<N;TRxv3>ZnaRbsSRF%y{GD$J_=5S) z`8heM$t9WjdAB$mlaot}5|b-$aXIH_q~>`i78E4jVs|qPPR=h%y~P%knw*%EbBi;e zC^b2=7^H(cxF9t-Gc7YYv!wDCzjJ<GS}91}H?gEBv*MOOGE5*o7b3!%oRgoIdW*L# zF()%69_An(a0-F)qIf_FCB8VLG%YPB^%gsnQN#!KNPcNad|GN^Noi4PaS=GhGAEa( z++r<AEGj7mr=eS1NMXlYlv-GtS(I8FpI(%ha*HW1?-rX+W?pJy(Jju@+@#c$_>|1t zTWpX-e~UdYzbH2`C;k>oQf6ZDE!L{kqWt(<T*0Y@rKx!(nTa{KIKa^mpH@=D2}<wk zAUCmqGTtqY<c!Rml%mwUTY`=$i3KH@WvPy3=>hrqIYueBSi&+>%Wts+rTT;xX@T^x zIp^n<8KvCf^i8aQsL>R;#ZsJ_lXi<0Y(<d;$b3tXTGs5uqV(ch?9lMJ#gbT*oDs!U zkXVwO0Zv{~ym|3CiNz)HNjdq+*~L-9iDl^p`T04Zbda1;lAoQLSA2^lBR@AaiW}ln zkW2<R&=PZRvE_gRqzLR8_T2oG(wx-dDAuCXVnYkCYr$+&gD4?rz{jUnloTZ<mt^Lp zuVlQ%84pQ(@$ujitO%3|icA?87(gkn_y#B!GV(EUFmf<*F|vV3W-(?iMj=KXMyCHP zOgxMnV497Q<v$A(68W8jSAdZVEC<pDQp*PBfytk2oC1t|jC{;IjBHG7jC@Qij4X^S zOe~CiOgxM{%sh--tO6iAm{=H@7@7WYvG6hSFp4m<fJ|WFVdP>2iE}WrG4e37{pMij zW8z^F0g3%*0qF&?7?ET_W;61ERYORyzVB7yIEvR~P-y`2I0%E<ZlG-I45}wqF)%RH zFvc*~GS{-yvevLHU|7gd%TmL<fH8$}A!99b4Z{Ma8s>$JwQMzP3z)%TY&8s7EH!K? zOudY?>^1Bu%#sYX95w7IERqbhoHgtztf2Ns4Z{M~8jc#4g-o^FC2TcZ&5X5NCF~18 zB8;^>H4IstH9V3GDQv6^k_;*A<_ydXHVkl44v=UKR}D`#6GSG3vxYkp)GFul0~gpU z86m|8TTXs@W=Ziaw#fK`qWq*=tda2<i6!8i2`;XS_(6#llyWt>Zn382<R_NgVlOT% zO--#zEs6xCWzNz(7(WT7_7-zyUdb(1a1nKjEwLy)H?ab2pC)fnB1jW!5-6$P;sjIi znJEz2TdX;WNvS!v7~^koKoe&1EuNy(;?kUw;`o%z<dULXkQI3#A|FJ69CnK_5|T_o zH4Qke-r@xp{+W5{@x>*n1q(sZBnOH(CN5BFW#VGuV&-82rC$za9u^K3E+#QXDW)nJ zw75qxBpFmPfy@A5HU<U;P~`?r03{5d27C=e79*(X9qU)iRKt+PRKq060BW9t5&)D3 zYPzx1FsCqQGZk5+Fx0Z<@su#vurxE)FiSJkvX-!<uq<G$VX9%RVX0wW$kNQn$WX`y z>5`Oyq84mExM|M>Zo;tpLE;G<MW7bRFL`j)u25ZTrI47MtWaI6qX2E1C{)+#SJ#r$ z+9?91c1QpVf`S|z%3uOqEQm5NFq{U7gW5|~rnrJIwIDw^BR(}R1zS5sldZ@GWP}xn z0F^OC1>i(dPy|XwY-yk(yhM`?oOGf@QKBUt(rPHa#h8o;j9bi!$;l8W!yS4{2(AzA zmP4R~0aC#LDtN#}8XFT2qZnhAB0+zk6bL-fh9fvVETB<=P{O)^t%hYGV=bs-$6mvY zCG$f>p_zXH$3li0Zdd_O!%@Ro!&Sqb!a0W}m_ZX-3>4*nf*dVVfvROh=3mKzp6MzW z7#LP^g5#*D5TqYe4ipuESl|)>lp%^rL0nK=-eSotNzGl!1CG6-GLUdNhyckIfs87u z1aYfCL^X&&b4?9M3=~O_2n8j&Tm0b0Mk1(%1J2=RK@kaRt}-xjF@b6&E=DeJJ;cGt z!3xUVa!gfnXi<u01WNt}C3$%MKEepf-&v5s0)`rJ_Zl={(92ZIRKk$On8KRE)PgF` zT*9z`DTNIozL2q&rGz1iDTN&&1~r4Vh9QeNg`<Q8l=GXJQaHhotCy7#l;@dKxWN?x za}DbPwuKA}(dF4|7(i75NUnwzD$CQ$4AQ%RvxWsGUdvX(mBn4dwt#0L!vfxg3?+Ou z3|aimj4(Qdx0kt=y+mMvAh>%flqFolp2C^Jm%`u6yg+0jLoG*%zyi@4j)jaMle5G? zRaTbx0*MrX8jc!<EXjq;jBqnu7-EBJLBj=7HJp+RpbmE}XkY@u6Ov@8<*wlZm4Oo& zi{erkYI*W_N~CMJn;C1kq#0^?N@P-m7s%Fd*6`GD*Kk2f!3m6odMP5Hkb}7wI>aGW z!;mGH!kNu9fwAa74MUcEil{hfKt#SoVSyr;FDAkOax=KR6!*g@|3Li`aBmNktkOZf zJswaGKZQY(0n*pgWQ3%CPy$@ZR3ro{ml!J`Nf9Ilu5XG!h2V2gIt3Sk=D1TDyy(GJ z3_>gBTg)k$xum;vC1Vk|UIe?U6XX~$3rv7~2X^EOkT}SZWSO52(hW8cm-(O}zap~D z2bTtT&3^@Q0m%F)4p5gA)N2aSWWL2wP?Vn@pOjd1iz7cT9vsNGKrNf})S_F=iKWS! zOhpqxIwyfjE~Z<onRzLx6`*c=PHApl@huKem#3sCF=r*yE!M=my!?_VmgL;Tf+_)c z_eTNJ+d<SZMRg1e3>cvZY88Uh6UdF}p!9?>dO$>d%E8FM01ib+DQpbtU@8@Xiq%`p z$%!SmSc^*%le3FJwN4SJDN<wy@*Jorz6I_qf?G<SATf7PdSxm|xy9*{Sd!=l>VV#2 zPA(|D#hO=|TTqFpc5g8k<docEOUq0zElRz`T%1>Yi#;PXu_QOK;1*YUYDqj)2GrXt zPAvij<Snk$ih|Ul%-qzxl3QGfDJd{Dyve0SMX7lukca@c`yiD?J;+H7Ai^6&fa;lB zY(<HADfzjeo>FQGw0RgM0E%2t(+yU=gEQeRR!|eQ_!ehwVg<OhT6~MGB(W$x6+8ly zmzE0(nv}Ho<lI{<X_@KqMU609Ag$OY5Eq=hzyv5$+!8^m2EgeC)JAy*s!mKm4LwE) zMhQj{MhR9aMg<ljMwb6<ECP%wj2w(2j9iQo%zWT>0~aHx)gZwr#>~UW!&s$EME1u> zIH2wYxb*;X6{NKXYDG4K2OQHGgBeyb`e`y3fy%2QP#3ZYRMToQK`K?S0N77p0%XK3 z5l}}R9Gi$i_F_Rs28Lpg8K4FahC{LGz}?JQM0PW$h9!lmmkD>zhZYSTO(sY>1Q#)A zEeudPgyb}EOS%P=Oj<!iJBR=WDVP9<D<^nh1C%yJL4gTvbbz`QjG#sbq+86v%*80i zD92PKftD8Fno*h?pmsI9xiJTCb7KJ$a&rUJzXUfckeV9MW(5mq7NLfzhGhW@sM%1< z3TZ}wc;G%ITMcUsTMD}kLk%0Kk>J7*s};jk%U;V-!d}BZn<0gxhGRCvTxL)wq=p03 zlw82MkfDaHhTVptumn8Qzzyy0fU0^Pa8=I<>9=M<cwC68zUT(DqUS}b=-E@a7x1B0 z^o6TZc)-;>sKEd>AKcyHsbN^aU&9P{J8y|VmS7Dp%vCVgz-Vw|hOb0ufiSo+Ba$Ur z!v|{0qzLpfFA!S@8b23WAYQ|eB{7>JMQ|<?sDV%-v_P_kAIuV1AXvke4(hvuM%O`2 z42cvWX+$$Zn2csbQ7=3x;i@{dV1<5_A+|&VFPw0;9E!TJ)kvV+4z7_vsTW)$p|x;| z`UpBr8|E~OVvDfp{RB<dften~0nQ$YrO8#y3bqRLscYR~wiiudU|<L-nhtXB3<d^< zDh~bR+<1_Ss<<FS5PHeQWt!69dbns7NZD)<0k3<Dyg-#LE2y_qTm-J_=7Z!x-9bb> z3n_&_-P<B?MF^^(i{^sV%>xmjp^73<$y&4+#9aa+mV%6B2aO}AWu_NdgM^lWgxE^* z<3TPG1Xlu4e2|`SJjAWVMQcC`*Mf+3AYw9z0L_2iV#!XeEWX87mY7qTT8y=hTn|#a z0Yq#BTLdu@t<ign4-__#Iv!MMYJ!SG&{#MFq{?I#U=(2CU<Zxa2{Az`N)B*EDaI(q z1g%7wg&4V*c^Io?@uql`R0m3Npf(S<vIO-xYZ$T^vKX_NY8g`)YZ%KIiu6*L7C>io zYZw=>fJWRv-GzmWwM=;|HH@`PRZ=CaC2R}WQ&?*lXET8Ee+^SJQ!R5iL!JmD0|=II zq_8zJHZevr<S`k5=3_ZQBX~6o@r>Y(1iK%&&w2~9{ReU#mfQv!6NL09z<~&wp-4_F z$Sg_B0nL4bihr;;gaBn!aAQ3dmNct$(ZUBYeSobrr~-}nrsfpuRq1=?f!j9<zThEX z1-HzcR0Y4p+*F19A_bV7GiW*pEUKu<b&EX%Jl0+eZeZME$xF;ly~UE1pP6@ywYVTB zv*Z?gdTL2xNlDQy&dl6`{GyVa%%oeq5GH6wqckV=7FQT3ZMhZ|<rm%J&CDw<Nz6;m z1p8H!1)Nxm4uH}XS8`%*YEC@Z3RbWWAprodY>Uhp7#PAq=}HFN7~^8$V3Yxm|FSW1 z{AXfh`p*FxDF!vVB*FUO)9reQ5JIVGK{HUGI0Cgli$P7O5{3ngHH=wI3mLPRL8}44 zbDYee@!2eJhr5;~lcAQihS`N7Rt_|RTf<VuP^4AD4pC9V)(jf1W3yo>)T&`!z)=IP zOKKR?7%~}ZSeG#Nfks@Jt4=9+A_sy(q5^nAK%ulaGcR4CBqJ3vN>^NvpO*p>0mlG% zrie=c38WR}=PH0E6!f5RqyTE@XXfW67FFsgMCO+&Bq!!6B<2+7D`e&sl$IzY=2a?y zCK^yoDalvJFH0>dFUl-Q#bP(OSWw8!L$OE?ZUVTS<fqAbi#;>1xF9vT1l&6SSyGUi z44xwZO=hGQ7wrcXAD|E}+6iJE21PhqX+cV2NvbA0BoUkhiE-tZWTY0wgG{)^1s=YH zu-P*6z~vMxNXGCMb7E0Cdg?gNz`&3JN*$mC!@w%W$iv72o|OPK&M`)FQPUwGXpRn) z^5Wxfam7OuZhZVLp7{8}(!?By4A?(KwIKI5gNP0gu?a+M0TEk4#2yd<N~c92=iK6j zPMsvD=H$SpPe9Hq$^xkaXKm07At*f;gZhXZ%sllB%p8gw+#He|+8it#A{=@gg5X(1 z+{;*02rOd(FGFDsX3%6UiUqX^kcX~8Y!C)@IEq2;DPd@4SilHsDb;}1Re*<<i#CJ0 zoXka_J`gw%Zm||;Bo?H?Ivm$Qo&oy^>^G1xU}fc?*Z}1(21YK%DqheO8GM2oB_D!p z2e|@Vii6CkVax&zcQN)c)H2mD1v6+eR>i^-T~I36_F`~yg(mIf<oqIVW+}-B3xicC zfNIGUQ0fI`s<O<~as|-Lqe5{>QE4)0f)SCBL9@po1w|mgXflISBxqouptR%`bAD+F zT4KA!o|a!!o>-Jp4fi4&4`US%*o#mpO}3)Tpn$moBCdi6(41M(Z4m1khyZyMVlFuB zz`+KJpfZ@59E?0H;3(77xW!ytSzH8aJltZ+D=0bz(tiX*faaTSu@xldr6d*?fm#kx z?8T{xN%=*2;5Nf8cF<IDPG%BhDwQiYu>`c{3M2pu^de9TBZ}L%(jT_+st6QcQEa}I zKKaGPMW6wUTim53nK{LJU=lLG03Lhb0MF;><>eIxGB7Yi3B$O0X{CAK+2+KYTU<zN zMz<ncP!<;~s01sG&jGDN$ji)2zr|FLb4#cUJmDN)0$K?H8skqbg3R`EgH=Oh!IQF4 zLP#Qd;Igy02($q27QYY7NJtF_&N#Q&pd3)xLc$vn@}NM!#bE=P6SV`)niT7SW^=#= h2Y5<|k%JL3;Rqr@Y#~MtW)RK8$ipO5&rr`00{|SWzf1rC literal 0 HcmV?d00001 diff --git a/models/__pycache__/vision_transformer.cpython-39.pyc b/models/__pycache__/vision_transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e278bdeba138e438387aa38b7a83ca7bb7819a7e GIT binary patch literal 11353 zcmYe~<>g{vU|{f3F;2Q@$iVOz#6iX^3=9ko3=9m#9~c-IQW&BbQW#U1au}l+Qy5d2 zbC`0OqnH^XVk}WCU^&)YwkS46h7^_*))s~+_7v3=t`zPT#wd;ywN%bjt`y!mOeuUR z{4K0e-0lo10x5zm3@L)CY*{?b%u&4V3@Jh>!YvFb!cY-DcZL*^6wwxj6j7)MzdJ*U zSc-THLy9<5M8KUPMIuGAg&{=}DkA94kRp{L-NKL}4HXeeQBRRgk!xX$5>C-bkxx-* zVT=+<(M(ZHQEFj~5_M-tQBF~5VMtMdsuxRP3TDvMs%qo%%`Yy=sZ>bLFR0WlNGvW% zRY)t!&s8YN%*|EE$xJFrEUMJw$|xx*D7MnqPtPpLC{5B!&d=2^D$h*MD9KIC(=Vtj z$uCOI(9O(EOi$I#%}+_qDb`QQ$xqVH1({c*4>C(1DqfaZoSB~&Us9BqSDcn#l$%<l zS5V3I5)>zXnvA#D3lfV;G81!>SwSLD%)!9G0K(3oh`I)fD8>|~7KU2Z8payt6y_Qx z5e5l{Y=&Zu8pawH2w$8bo2AICgdvMDg(ZbKg|(N7k)ecX0rNtJ{uV}t8rB--8b%3* z8YXFmX2x2!8rB*XBtCl$!vc^wY$6OPY!VD9%r)!~)<Pymh8l(itP2^yG+PZ*3Ht($ z6!wLTA`A-|!D5^>j1X29R}FIwO9}_b<{E}@1}26S1{MYuhGu3)hF}ImhT;<j42)nH z$-u~v!Vt`$$?2EO2=y|E0<qZ{7#M;<-j)G-yM$o@V+}(I!vdxhriF~s3=0{>8B&;} zL7JKUo`W<oFnq9wG8h~*nQyV=CYEH}VooheyTwvmSX83Pc#E-OB|{M>0|Ub^EB%c8 z+*JL_oWwl+^8BLg;)2BFRQ(W8-1vI>;El3d9C`Ugx$(&<X|f>aF@fx6s?vzh%a2bl zN=%6_DJsoNjt8kp%!${t$;nSn%qh0h0||gMy##5f%F0bm%u~oLR>;jSN>wPyNX%0( zQYbD-fh57`L>+~sScS|yg}gky%)HDJJyb&#LNZbnQZkE6iZYW*OEUBG6!Oy)$`W%* zQ;QXH6Dt*xQWY}ulJkp-Qj<&cG#PJk6_+F?XXm7rrRKZ@g>4Zi^J_BQV$RLXyTzQF zSfR;%i=#ZTC@(WFz4#Ujh;@q<oV0JTL%ns2qckrQlx*W~v8NWLW#*N|-(txv&517p zrPm@*Y~5l>Oi78q#g?3tm|GC9$#ILVBsH%%zvvbV$g{VYi%U{&F(%$(Oe*4LU|_h# zm~)G<6dW560#sbx62l&hQlO9qB?Ja$4i*JQHYTQjY%E+X0*nGIGAta-B8*(je2hhk z3=9mg6u?l#jKn(t_AyG<0y&$Ffq?;>wFDR#7)lsw7-|@6m_UiQnKhVUC8M7v#7s@r zB9H@$q!}0(ia@R?0>ylh9Eb}_O-SzNMTsaokh4MY&A?Qo%D}+jrpa`RGo>iMAikg| zKj{`nNl{`ZD9n>VZUv<y5C#P$++isU3=B1lDU8`n#XL1k#R4@9!3-%(pd4GvT*HvX zkism<kj9w8QUl8GOcD$t4B`y6EH%s}Oj*n|AW;xzO=0b276+9YtP9wnIhehMc_CvB zs|Z6aTMb(<gC?6_krK#1d>{gpMl>0VKx|E>TdXN1l?AD{*iur<GLus^Ic~8QXCxM+ z-eSp1$;>Sh2E`>yQDR<7kp@VwCWz0LmXn`f6n~33C9_PEy$Do<6{&$_)fpHVqByft zQz1c6qy^&hf^_HS#zT_}TYhOtL1_sjRKej3P8!gFODxH70tG85$`}}V7+DxO7^N5` z7&#dE7>f)T7#Na4fryMji5<i*23eNM5XG3n5XF?j7{#2z-VO>p<`k9|&M1}?))clD zhA7r{1{Q`WP>^uk;&3U-F9=91$*770u@njtOEQWTG=fX=lQR;FOEQxcTv7{4GBgzm zQi~Ld6LSl4QWX?5$}>{)6cP&xaxzm>KzTSfF*6UW6_iVhQj0TFN)vMwl5+Bsvx_zL zxD*r=6#SCGsRZOo1_lNY8<ZkIxL5+5G)ow37@C<DFo9!_v4$a@8B}f-sW31wXtLa5 zEiNrcEh;hu1wBW6JgCHoj|T_1CetmJ;?$fpFlQy>EzbD(<ebFf;`n$-@Pb0N$e4kF zApjJn;E-bEW2_Pcc^BkyJ(y`24g{%1&O3~tBv`|c#nj9k%m8s1GdMoLRu_RHekD_p z2`CYOlLI(n!34++x7gG2i^>y=QbIuP00jU8W0e5h6;Me{M!%Kpw>aYC^AdAY<Ku5} z#mDF7r<CS^*gWy^g{6r(P#O04_>}zQ_;|42iY!0@!wdC#a%xTv%==*bO+jt~IRYGj zz90vH@*f8y3lkp;SPXZP0wp*^l1gDrWzJ$rVM=2`N?5GWgvAC<SnR<Jn(VijeRB#R z87>)Q6*S%%Kx|OBf+OA!9P!Nz3mCzX&Rif+!&u8)!&D$q!&u7%s?tHF)k3BkhIp15 zmS%<;hIm%6C{qnHSd^`VX#sl;>q1a+;Xou8&LSg_he45kizzSf7Mo9IUTR{|E#|aj z!&}UW$t56`(JgjR;OCc?++qP0Sek6$sJq3TnHQgynpjd=lv;d?KO-|GB^614C%?1= zA<dbXToRv?SeaT><OZ@49M0en2M0FD0g<4<1{E_5j9iQy;5-b<$5lMOIR&_r9w>Og zNzV*aIx;L^0JS`78EZgI1jdDoU=|Zxj2X^?i7~<wC`%EjC@BJ&1#zMa$erdO0_;vO z0d_1X1tozT%L9rlMlLw4;s!YtmOeR(tU#KrL4*y60JX4hvADbXgd$R65vW>#WIwP4 z;BW!kV$Z<95CO6Uln6K&B{+*fA{eO;RCR%zQ4A_l;i-=koaR7<7$>M`OW|l?h++Y! zIo4nXO|Dy<jwL0jd7#$9OHc-MgQh=FbptXCghApU!$ARDoQ9hIK<SJLl;WfqYT0X8 z3Y2QtQy7IosgAjXB@5IbWlUibVE{Etdzot(;@L`A5UH<(WdT?PI9+mrW!S*!k_(hB zHCg=9LDn-cJlOvdWcNzOTP#VLiN%`SNNJTbuQWG4BQ-Il_!et%a$-(u5h!qr96@=T zxiGt|$Q8umOe`tMiw7myTPy`d`B}F(K_pl}lLL~5nL*W7AV@JsVRl(O$f8>uh1u~C zLqP>lQ5aYjWLZ2&F*t_6aRrVhP|D2#MH8eB;$mb2rCcU4MjlYvDFhE^>?s&z1Sp_E zsjhejC<QZSF~l&}GSxEIvcS^t0wz#HBn4EcLh5m5DTalt3s_Q^QkYXfEqm65thKB) ztf2b8hBbvnlA(q*g-McunSqI+maUe(mZOHPhJ68R3R?>MLM8!*8ip)35r$gM8qO^C z6!sL3UdCF^8ioZNHJl4UEY2DZ0fq&vDJ&^W3z;Bx*RU*P2InMjyB^Hqg61T~8qPFO zCyLXr2vkTFf$~HVsLcRwF}wsN-k1OW|NsB;|NsC0S27lbgK8;ma9v*nYSQ0gFG>YB zA&S7Uz+RAAlv`SodW#d>H!8?4PAv)oMF4woep*RxVnq=+PC`NX2^2A!T;O;Fw<>Ni zI^AOQyT#}XF_^LN7Gw4;#<E*1pk#oQ$3TG&t_I3M@g@RFZHzpOdW<sQSmR=n0%zzd zK}6iavNUTEs78S17H}I$lMOwWgAxSTblh3nhk=118)P>qbT}CKxQjp{7+D*XFhGtg z29+N0tj!3Hc9s;D6c$hglggaMn#O^Y-Puw&L6O?R7{w0G@EqU_&*}tfVn7?j0uE3w zB!e2?AS*x^6muXeL7A?&05!vN)i5@Lst%?SW>B`TVX0xQVXt9JVeDlCS2-LhOxY|& zRwb-89L<b*Y$a?9*cXC|1CAP)PR<f8P(M`!)YIfH;jUq<;jZDVVM}4|WvO9^=YeK@ zO;&I+C;}CHMW6}+65XJfh6FcDkvk~rJwSvPhyc|ln!HFw11qR=We5(@qD+tqP)i70 zJlx{&OaT>oC6(Z|$Str+qg%|Gc_rYo^%iq(PJt#5Bte1FDcCKyICFCf;)@bXGV_s~ z1+o_6Hc+i!1aitPE|ATTN-l~IqzPIj;wUxRK`B586!J`<DujuV=RXHH?eQ>H@qwcY zQD%VR1f&v#L4`6Xp1@@WC?IM;wH~y-TgVs#DLa^oWNKJ~L7f<e1)xR%qc}q?V-2WY zWU66Uzzpi!gL3dfMo{M&Bm=7lSs|eYD$~HhR0K*2MR}lX%L)k^kma`oic(8Ti}K<V zVHL?O#!84zaM=hZ;2CWa$g80A#lR@S$im3N$j2zfSj7wRBrJ0=x@mHvXC`o|?hVTS zoIaq&gI|78E-{s776SuA8^}_STRA}a`9B|b5l9pxmx1yiC|EIj1YF=;#g@Vj?h$aL zaH8}GxNmU<B$gy+xaKCMrc}u(c!I`#6iV_Hz@iFZ;grn0ba0#g<pTx=hL@l!0$fVI zWMF1sK<R;kT6-W2is)ic%cFz=-r8fVVUhwxB$E__2&m%*&WJUjI+(eJ6_ou-ShLtb zJ=Gef8m1IRaE4+BXDB8=O{QDi*{MZ&sX6h*nN_K`*g)e1DXE$qNI8ZhGdCR~%T)mK z5roYRs$#+X)M7~XVRO#UD>F(d0u{c{HegW=D2LU8B8Ve1FFrXVF|YU*XDTR2z?Eek zNE%CxIS&*!kOaZSD8R@Df=pGy;GhGC93tt1+y9^lVqk!`iNQ%9)Pt)5RhKM~>avy< zQDLSqE@WK5oWhjCxR42wxHXymAaxU{g{jE`senK}xy7EA15U(6iJ&wNY1@M9lOj-B zDM|%-j?v>5V>q@dWI4zqpon8&<YSaztP(=<0xV&(6oD$5q6$!0fa}Og5X%xofU^mh z00$B{!;~>FFwBM-3+{Mu7J<Y-u7QRFgaY{l#40XfU|>Mh4^hl1OzDg%%%HJZP?Zm= z4N^Ek<FhH;Dcq?nS*&S%;FQgi!rQ_b#ReXy1&!f?M`@!tAY-#poRAUOC@#o2Y!r8j zbc$q(R10Gi54c|84Q9}kxg{6|9?1*=4VHmMgi?#DL=+%m3P^$q=s^H(A%Mc+B`6gX zfoeloox~6F6BG+FFff4XC=f2@VFV4#z_LXLGYhDO0cV2}7SOneBm=0O)WMv>1kM7i zH4Is-S!}fepwY<^&IMd4%qh$r%nLzd9Srf@2+<mW6lQS-xC~DTdlqj8QwmEjD5LQ~ zdTb?}3;0tQYj|te7BagqG&6QGmas1nXlH0=Y-eg`Zf9v{ZD(s|Z|6t@RlaO393_G^ zY|V@ymEfkWPzP%VGq?s7u3<@G%w{RNR>HnOq=UJJWg$}@TM7FDQE+`JRw7=*kj0b2 zk-`ZYF$MFuVZ0iK1rjw3S(2cURdDv^@&mWDG+7YYu&4o)tQtWCD0zU!ISWvRQ$g7Y zROEr$Ud5ocPdY;lLo5#{lPq9Z$S{$qkR=#0E~?3RizTxpHCL1A7L%R<Bu9XgDwqHz z)mvOPIhn;J$@#ejc6%8Z7(Rnytx69P{YbG7&&-GsP=rPd*YPqiFn|V<zy&hcbxO62 zB@7+R9c&#e9UL9(9ZVez9gH=Mk_;WJDGa@w5I1WwYI1^Gm*85vs1Otch-L&@eO&~~ z>WttDL6fPd6P(2}S90Fs12q-F9RyH~pBlxSQcwgPr7KD-Ni6~gEZAs}2@spW%@nXr z@PH}mVqjp{2g<{sU}9ro`pd+~#mMvTd9}zo_lH$e>{l{^#||~6kcwd*P=O2UPeW`1 zW&2y);6gYaT&Na-3sg|8c#AV2u_!S&wIsC&GVI1$m0FZve2X(Vr#QYOKRY$=7H2_z zF;pA4JH!FvgL^|9IhlFjQSMt@zThzvpUmQtTdYNidFiRQ*dU`(w^%^U8E}`XC>In! zETAqBc<AyLD`-Tk@)mDAcyy^e6*LD@tSJa-9>A*va9X^@4e}OvoGG;!?5kU>DWHKe zNZ^6P5*&EoVC)CQS|4Z}kPqb0_=3ce3~)>ofx`P1TXuP3QF<{X6u>2BFGvF@kio5j z^`IaJje{{Tb20KTL7G$?Ok#{ujC@QIjB+e&f4Eq=nB|!GK%Fxw+zA!;2p7m5B@8u; zC5%~2%}hm3payXkcuowQVG%NUHH=xTDGa5IMR_HRS!@f~!F*=W<Vl_m)C7)ZCQMnl zYL*lhkZN#NXZ0(Z0P^ZY5CJYcK`jMM&RblW#hH1<C5d^-NPS>PEU`f%rU*2gcZ&rS zRkt{k^TCQs;)_Zd7#N^~aJLw9(d)Ja3=9lsK%ogLML<KVe2ij@LX13&0*qDiSi%it zdOXCw;$%?y0xAkX7*ukD@&&lcc3}jKU}rI;FiJ9{FiL<%rx~&sk@%n)DAsJIA`?(a z4yGBh7=uAQR>l;D6oyu&6s9yLNzgboXp9>o!Uz%vi-Lj}JPKaR0csv)F=a7Jf%>_i z(QycytA;a$MVz6A3q&$8)N<Bxfd(wHK>cI(6vh$`P(1+}CTm7v*KjOkS-_pbl){?A zn8LP@rG^uxPMV>H3&Bof0#&xX%(d(#95tNHj5X|83@MzF4B1Q*7>oEyIBK|{V$2Y+ zVj(YvCdL|eu-R-N(-*QV-~mm#)^IEYnZldJSHre|A2iKX!=A>J!kEI<3+ioi`+>XH z;Nrh1mVtpGeb!76*l7=LUxHgYx7d?1lS`8_lQo%c@qouO;?ok7OTcq5`6;PIF`%Rr z2Py~!Ag!V4Ac4ss0@R>`wE4KwO7oII<3x!$pz*$v)S`m?98gIZ4^o#8B0&8S@SIOk z0=Sqc%S<hY6iLj<i6yr<OY@2gOH)&;QZ>0D1q7%~ev7gE7Gp*fTOPP@D*{bzMDc@5 ziFi;c1upEO_`#A00dVnti>ch;7E^`+T8h0Tk(mc~9YjlBaz3a5eif9A!G!`NAEN-c zwJF3X1}cvkd6<P5*%<X0r5Kt1va#xbIwe||i5r)}D1`)Q=nT{j1gCRQe-(Sv6f|~{ z#k7DK)KE`hT*zF@21<W5Y&DE&OrVZ^FSxl{!?1u2G>coqx)3y11M0Q0gL-Sh44N!{ z;P{3_KREU{Ah8FIxM`pmU`wqi08PPwD$$}1peWh~B4&W1njH}Z;O6Zuro7~$O(5yb zpi-9$UV9Xyhc912QEEYAQ7WWzDSizKRUS|dXXIn#VisT&Visbmf;Rn->wZ*iC}9RF z!ax`tW}s13a8DLGYMaIA0&16l#=aPv8EY8}1!_P&S%@sGG++ef`(Or5CTL5Dxo9WI zFLOY|T#$d65&pf!oRpJ|<}2_(%@2@QKy?Uc;2G2v=VGjq#7qXz<~)iIKqVb0{eyh~ zYTF}x0Gg%AVgj`RV@+b1KwXp?(C7(hQiKsy_GL2^S%Bu>!G03~u^{o!!~mKZFBAb6 z`z(-r9-_%uR1ESC3)p|RI8*aVb5p@B5|CGKG3TV_X|f=^TGRsyfL$O0RQ5tV4f5eF zvGi2XIDBzQJZL5e-0uP9?*GU^AjBxd$i-Zxff)qYHG|V4%CsdYBY-eCSU}YeB3RNH zYMGIO#wdoVmIX9;2^wIkWh?;~)y<4GETClx*^EUJHOvb@1EH*-<PU227E07G!-gu@ zz@bsJ8{|(=I27#xC;!BPg4Dc}B5?d8b?0s|<`wCKl;R8=aghJOxji>EB{Q)k6*8Dp z%)!XOPz<RExtLfOg%~*)L19#-i9Lj{n4k%rK;=Up;IIWb2-Kbf*T>*SOcAK=Edpm* zaAF2!j3Q9;y66B%G0q8tRt5%!wV-$u11%H*53=wuvi;+e<>2HH<KX3B=in>?sR4IH zQCi)gJP69lpn;ZR9q`Cs8bb|u!4ISxS;7dKLk4vwdzn%==CIc=E@A9rtYxYJ)kU0s zFToB2jkiE*pJ)aKhLGpIZ?`V^er$u?O2%8vsRhL=*^9u{3b+Zq2;>vcfG<)z9n`W~ z$pUdAsFwn+fWX1Bk{P0PF(`Uj!Aoc$NgABE!ReZ#EVCrOBr~s47!g-2j9iR-|5!ll z1DK0Wf@Xt34JS~M1;Y3|4H{TyfqI%Xg{_w<g&ob)9DbP%j0_A#U{8bnTXYEIr^6uP zD2O-)B94Oyuq(j?*r}W#rxxcX=Hy7>aO){*yR`_^(<aBQ;9yNkEKXI#;ndS0r-FhL z8G}kh5WiRtd0>JEJm>%!PGbi3X+d*RDWGLz$P?|nw>Vrp{ro*r6I0wmz{%+)Xw}n8 z21W*kWKhciWH1PW!V*;Jg5t553uS1agsFx#g|V3lRIY%QOBCspfJza@8g}rU6KmcS zD4%U1Q#xZUM-6k4Uk%3swi4zA>@{qlAyGyKD4$~?cu7+UCumAVf&pBdThwwwWm#(2 zYS<Psfy$*C=AubZQ5<FzW<cdZy4WGQN|>{_YB>5p0~w&vD{fHdA7okya~AIcz7pmv zmKydNaBA#@3|sKmFl6y(31kT_5CV;HF*4M!OkgZxO97Re{8<9=!r*Fy6;f@0n}Kdc z{-Bye1kuwh0*^z23mwpe8DxML)LXs9?F1VC0}ZkorWAongCfvC)Ge;y)WTBG$a!K; z(R@&f29;2^SW8PXbBb?qLt6ddMNYT4auSO{^-E=H(Jc<BMEWh3l*E$6TdZlBIXUsS zc#BdCOEZg7i@__Tpz}Ak*fR4#1BSQQL9@mwnYmGHrNyc7NqJHHpk8P^lCkW0;A!OI zTU?0Y^(eli{F0KK)V$Q>Y_KtGP&JJ4Mc{5CJokZ1SWQr31}#v5RQMc>0!)02EQ|u+ z_Ms4%B>?X1@Gwe&+ly6#urvuD%0OubfQl<nQvh7wK}K)kJwykbJwyj^j}X*Dbb$By zSX00~L~ty!p!fMSIf|Bmf*%oepoK(5OF?3wTDE8zhy@ywELs6#fqe}oK$!(xuQ-6b z4$em42}&`>DlvrDP&#oSKY{WY2!q(5ObzxQc-$2<<p6Gqf%?H&%%I^naPtK+O~?YO ztwBRdNDWp{1*6FXZo(IVram-TAqfxc8Bo1-ixU*pi8+~7sYRDT&c`+P?uPIK7pSQM zS{wo`GW4L0G)-1$;ZbxJWYRfM4q!*Gd%-hxprj0L1>&Br1NDP7Kz4$fG#re4;v8&6 zATdoA@SGHQaaoZv$R1D{&}0IOx@j_kD_S&4ya-(5g97swC%A8wpPy4)1PWkCpNRv! z)J`ui4^sUpf{X>l@-5c1oczR+TP#WW`8nW;nOkg-u^aHJ6z~KVWYP;X%K;uRhE%1X zimwP%qZEP4YOuqRDo=RoxW!=uSw3S2YUmUzF)%Q&uy8Q)FmW*QuyCkwu?TSsaSI6x PNehVz2??cfv4bE0KQ+Zt literal 0 HcmV?d00001 diff --git a/models/model_interface.py b/models/model_interface.py index c7bb723..1b0f6e1 100755 --- a/models/model_interface.py +++ b/models/model_interface.py @@ -4,6 +4,9 @@ import inspect import importlib import random import pandas as pd +import seaborn as sns +from pathlib import Path +from matplotlib import pyplot as plt #----> from MyOptimizer import create_optimizer @@ -18,9 +21,11 @@ import torchmetrics #----> import pytorch_lightning as pl +from .vision_transformer import vit_small +from torchvision import models +from torchvision.models import resnet - -class ModelInterface(pl.LightningModule): +class ModelInterface(pl.LightningModule): #---->init def __init__(self, model, loss, optimizer, **kargs): @@ -37,11 +42,11 @@ class ModelInterface(pl.LightningModule): #---->Metrics if self.n_classes > 2: - self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'macro') + self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted') metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes, average='micro'), torchmetrics.CohenKappa(num_classes = self.n_classes), - torchmetrics.F1(num_classes = self.n_classes, + torchmetrics.F1Score(num_classes = self.n_classes, average = 'macro'), torchmetrics.Recall(average = 'macro', num_classes = self.n_classes), @@ -49,17 +54,19 @@ class ModelInterface(pl.LightningModule): num_classes = self.n_classes), torchmetrics.Specificity(average = 'macro', num_classes = self.n_classes)]) + else : - self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'macro') + self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted') metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2, average = 'micro'), torchmetrics.CohenKappa(num_classes = 2), - torchmetrics.F1(num_classes = 2, + torchmetrics.F1Score(num_classes = 2, average = 'macro'), torchmetrics.Recall(average = 'macro', num_classes = 2), torchmetrics.Precision(average = 'macro', num_classes = 2)]) + self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes) self.valid_metrics = metrics.clone(prefix = 'val_') self.test_metrics = metrics.clone(prefix = 'test_') @@ -67,18 +74,103 @@ class ModelInterface(pl.LightningModule): self.shuffle = kargs['data'].data_shuffle self.count = 0 + self.out_features = 512 + if kargs['backbone'] == 'dino': + #---> dino feature extractor + arch = 'vit_small' + patch_size = 16 + n_last_blocks = 4 + # num_labels = 1000 + avgpool_patchtokens = False + home = Path.cwd().parts[1] + + weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth' + model = vit_small(patch_size, num_classes=0) + # model.eval() + # set_parameter_requires_grad(model, feature_extracting) + for param in model.parameters(): + param.requires_grad = False + # print(model.embed_dim) + # embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens)) + # model.eval() + # print(embed_dim) + linear = nn.Linear(model.embed_dim, self.out_features) + linear.weight.data.normal_(mean=0.0, std=0.01) + linear.bias.data.zero_() + + self.model_ft = nn.Sequential( + model, + linear, + ) + elif kargs['backbone'] == 'resnet18': + resnet18 = models.resnet18(pretrained=True) + modules = list(resnet18.children())[:-1] + # model_ft.fc = nn.Linear(512, out_features) + + res18 = nn.Sequential( + *modules, + + ) + for param in res18.parameters(): + param.requires_grad = False + self.model_ft = nn.Sequential( + res18, + nn.AdaptiveAvgPool2d(1), + View((-1, 512)), + nn.Linear(512, self.out_features), + nn.ReLU(), + ) + elif kargs['backbone'] == 'resnet50': + + resnet50 = models.resnet50(pretrained=True) + # model_ft.fc = nn.Linear(1024, out_features) + modules = list(resnet50.children())[:-3] + res50 = nn.Sequential( + *modules, + ) + for param in res50.parameters(): + param.requires_grad = False + self.model_ft = nn.Sequential( + res50, + nn.AdaptiveAvgPool2d(1), + View((-1, 1024)), + nn.Linear(1024, self.out_features), + nn.ReLU() + ) + elif kargs['backbone'] == 'simple': #mil-ab attention + feature_extracting = False + self.model_ft = nn.Sequential( + nn.Conv2d(3, 20, kernel_size=5), + nn.ReLU(), + nn.MaxPool2d(2, stride=2), + nn.Conv2d(20, 50, kernel_size=5), + nn.ReLU(), + nn.MaxPool2d(2, stride=2), + View((-1, 1024)), + nn.Linear(1024, self.out_features), + nn.ReLU(), + ) #---->remove v_num - def get_progress_bar_dict(self): - # don't show the version number - items = super().get_progress_bar_dict() - items.pop("v_num", None) - return items + # def get_progress_bar_dict(self): + # # don't show the version number + # items = super().get_progress_bar_dict() + # items.pop("v_num", None) + # return items def training_step(self, batch, batch_idx): #---->inference - data, label = batch - results_dict = self.model(data=data, label=label) + data, label, _ = batch + label = label.float() + data = data.squeeze(0) + # print(data.shape) + features = self.model_ft(data) + + features = features.unsqueeze(0) + # print(features.shape) + # features = features.squeeze() + results_dict = self.model(data=features) + # results_dict = self.model(data=data, label=label) logits = results_dict['logits'] Y_prob = results_dict['Y_prob'] Y_hat = results_dict['Y_hat'] @@ -87,8 +179,13 @@ class ModelInterface(pl.LightningModule): loss = self.loss(logits, label) #---->acc log + # print(label) Y_hat = int(Y_hat) - Y = int(label) + # if self.n_classes == 2: + # Y = int(label[0][1]) + # else: + Y = torch.argmax(label) + # Y = int(label[0]) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat == Y) @@ -106,19 +203,28 @@ class ModelInterface(pl.LightningModule): self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] def validation_step(self, batch, batch_idx): - data, label = batch - results_dict = self.model(data=data, label=label) + + data, label, _ = batch + + label = label.float() + data = data.squeeze(0) + features = self.model_ft(data) + features = features.unsqueeze(0) + + results_dict = self.model(data=features) logits = results_dict['logits'] Y_prob = results_dict['Y_prob'] Y_hat = results_dict['Y_hat'] #---->acc log - Y = int(label) + # Y = int(label[0][1]) + Y = torch.argmax(label) + self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y} def validation_epoch_end(self, val_step_outputs): @@ -126,13 +232,26 @@ class ModelInterface(pl.LightningModule): probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0) max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs]) target = torch.stack([x['label'] for x in val_step_outputs], dim = 0) - #----> + # logits = logits.long() + # target = target.squeeze().long() + # logits = logits.squeeze(0) self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True) self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) - self.log_dict(self.valid_metrics(max_probs.squeeze() , target.squeeze()), + + # print(max_probs.squeeze(0).shape) + # print(target.shape) + self.log_dict(self.valid_metrics(max_probs.squeeze() , target), on_epoch = True, logger = True) + #----> log confusion matrix + confmat = self.confusion_matrix(max_probs.squeeze(), target) + df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) + plt.figure() + fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() + plt.close(fig_) + self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) + #---->acc log for c in range(self.n_classes): count = self.data[c]["count"] @@ -156,18 +275,24 @@ class ModelInterface(pl.LightningModule): return [optimizer] def test_step(self, batch, batch_idx): - data, label = batch - results_dict = self.model(data=data, label=label) + + data, label, _ = batch + label = label.float() + data = data.squeeze(0) + features = self.model_ft(data) + features = features.unsqueeze(0) + + results_dict = self.model(data=features, label=label) logits = results_dict['logits'] Y_prob = results_dict['Y_prob'] Y_hat = results_dict['Y_hat'] #---->acc log - Y = int(label) + Y = torch.argmax(label) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y} def test_epoch_end(self, output_results): probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0) @@ -176,12 +301,20 @@ class ModelInterface(pl.LightningModule): #----> auc = self.AUROC(probs, target.squeeze()) - metrics = self.test_metrics(max_probs.squeeze() , target.squeeze()) - metrics['auc'] = auc + metrics = self.test_metrics(max_probs.squeeze() , target) + + + # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1)) + metrics['test_auc'] = auc + + # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True) + + # print(max_probs.squeeze(0).shape) + # print(target.shape) + # self.log_dict(metrics, logger = True) for keys, values in metrics.items(): print(f'{keys} = {values}') metrics[keys] = values.cpu().numpy() - print() #---->acc log for c in range(self.n_classes): count = self.data[c]["count"] @@ -192,6 +325,16 @@ class ModelInterface(pl.LightningModule): acc = float(correct) / count print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count)) self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] + + confmat = self.confusion_matrix(max_probs.squeeze(), target) + df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) + plt.figure() + fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() + # plt.close(fig_) + # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) + plt.savefig(f'{self.log_path}/cm_test') + plt.close(fig_) + #----> result = pd.DataFrame([metrics]) result.to_csv(self.log_path / 'result.csv') @@ -226,4 +369,18 @@ class ModelInterface(pl.LightningModule): if arg in inkeys: args1[arg] = getattr(self.hparams.model, arg) args1.update(other_args) - return Model(**args1) \ No newline at end of file + return Model(**args1) + +class View(nn.Module): + def __init__(self, shape): + super().__init__() + self.shape = shape + + def forward(self, input): + ''' + Reshapes the input according to the shape saved in the view data structure. + ''' + # batch_size = input.size(0) + # shape = (batch_size, *self.shape) + out = input.view(*self.shape) + return out \ No newline at end of file diff --git a/models/vision_transformer.py b/models/vision_transformer.py new file mode 100644 index 0000000..ebffe7a --- /dev/null +++ b/models/vision_transformer.py @@ -0,0 +1,330 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial + +import torch +import torch.nn as nn + +# from utils import trunc_normal_ + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add positional encoding to each token + x = x + self.interpolate_pos_encoding(x, w, h) + + return self.pos_drop(x) + + def forward(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x[:, 0] + + def get_last_selfattention(self, x): + x = self.prepare_tokens(x) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + # return attention of the last block + return blk(x, return_attention=True) + + def get_intermediate_layers(self, x, n=1): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, #num_heads=6 + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/train.py b/train.py index c4b4d7d..182e3c0 100644 --- a/train.py +++ b/train.py @@ -3,8 +3,9 @@ from pathlib import Path import numpy as np import glob -from datasets import DataInterface -from models import ModelInterface +from datasets.data_interface import DataInterface, MILDataModule +from models.model_interface import ModelInterface +import models.vision_transformer as vits from utils.utils import * # pytorch_lightning @@ -15,8 +16,8 @@ from pytorch_lightning import Trainer def make_parse(): parser = argparse.ArgumentParser() parser.add_argument('--stage', default='train', type=str) - parser.add_argument('--config', default='Camelyon/TransMIL.yaml',type=str) - parser.add_argument('--gpus', default = [2]) + parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str) + # parser.add_argument('--gpus', default = [2]) parser.add_argument('--fold', default = 0) args = parser.parse_args() return args @@ -34,20 +35,31 @@ def main(cfg): cfg.callbacks = load_callbacks(cfg) #---->Define Data - DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size, - 'train_num_workers': cfg.Data.train_dataloader.num_workers, - 'test_batch_size': cfg.Data.test_dataloader.batch_size, - 'test_num_workers': cfg.Data.test_dataloader.num_workers, - 'dataset_name': cfg.Data.dataset_name, - 'dataset_cfg': cfg.Data,} - dm = DataInterface(**DataInterface_dict) + # DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size, + # 'train_num_workers': cfg.Data.train_dataloader.num_workers, + # 'test_batch_size': cfg.Data.test_dataloader.batch_size, + # 'test_num_workers': cfg.Data.test_dataloader.num_workers, + # 'dataset_name': cfg.Data.dataset_name, + # 'dataset_cfg': cfg.Data,} + # dm = DataInterface(**DataInterface_dict) + home = Path.cwd().parts[1] + DataInterface_dict = { + 'data_root': cfg.Data.data_dir, + 'label_path': cfg.Data.label_file, + 'batch_size': cfg.Data.train_dataloader.batch_size, + 'num_workers': cfg.Data.train_dataloader.num_workers, + 'n_classes': cfg.Model.n_classes, + } + dm = MILDataModule(**DataInterface_dict) + #---->Define Model ModelInterface_dict = {'model': cfg.Model, 'loss': cfg.Loss, 'optimizer': cfg.Optimizer, 'data': cfg.Data, - 'log': cfg.log_path + 'log': cfg.log_path, + 'backbone': cfg.Model.backbone, } model = ModelInterface(**ModelInterface_dict) @@ -57,12 +69,18 @@ def main(cfg): logger=cfg.load_loggers, callbacks=cfg.callbacks, max_epochs= cfg.General.epochs, + min_epochs = 200, gpus=cfg.General.gpus, - amp_level=cfg.General.amp_level, + # gpus = [4], + # strategy='ddp', + amp_backend='native', + # amp_level=cfg.General.amp_level, precision=cfg.General.precision, accumulate_grad_batches=cfg.General.grad_acc, - deterministic=True, - check_val_every_n_epoch=1, + # fast_dev_run = True, + + # deterministic=True, + check_val_every_n_epoch=10, ) #---->train or test @@ -83,7 +101,7 @@ if __name__ == '__main__': #---->update cfg.config = args.config - cfg.General.gpus = args.gpus + # cfg.General.gpus = args.gpus cfg.General.server = args.stage cfg.Data.fold = args.fold diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dcd009c36ba1c22656489be43171004bc84df85 GIT binary patch literal 138 zcmYe~<>g{vU|@*eJ0}rDKL!!Vn2~{j!GVE+p_qk%fgyz<m_d`#ZzV$!NEku<($~+( z&rQ{@%t_4CFV8Q^E-pw+PSp=7O3W+v_4Ls%Ey>I&){l?R%*!l^kJl@xyv1RYo1ape MlWGSt=rhPd0F)gc4gdfE literal 0 HcmV?d00001 diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1eeb958497b23b1f9e9934764083531720a66d0 GIT binary patch literal 2832 zcmYe~<>g{vU|=|P*fHrQHv_|C5C<7EGcYhXFfcF_PhntSNMT4}%wfo7jACR2v6*t1 zqL@+`QkZj?b6KKT7#UIoQdm-0TNtBQQy5a%a@ca&qc|8DQrJ^CS{R}@Qy5Y>bGUN3 zqquW<qIke2aOLpk@<s81+1xq&xdKrFj0`C}DZDKVQGzK9DNH#+Q9@u^I9G&;fsr9s zG)jz-A%!o6zl9-6JcTKkK~wN0$i143w^#xaOEUaG442I0l4M4ZFcgD)!NI`5;0*GK z2Ll5`2}2D-3S$aWFH^rpEn^8|7E=v~WKLo3WvXQ=VX0wiW(;RwVn|_NVPIisW@cmv zW-w$ZwlZK~1j9%Mkeh=UG+F#^F&2UB%*{_p)nvTI=98b8l3Jw6dW$8$AT{q6OJ!ni z&MlUl{KS+ZHU<WUB9K!wnQyVC6(p8q++s<~%t^h)T$!7*lA(x~fq~(dj($ddZmND| zPGX*Zd45rLaY15os(wgOVqUSYr;mPVNoGzlgw!jjyv13RnwSy~vIAszF^FJb<YN?I zEaGQiV9;c|#h#O&o}OA%j1n9mW$duvIK#leP{WYLki}TbSi)4p*vy#4oWckR97cu` zmJ+5KhAh@*#w@mG#uVlnhAj3hj@b-znZy}tnc#Au$Yp6|l4Pi5hL};qB+XF6EX|O@ zTEmdVmBkIxA<a<35YJe`lfqWRki`fRX=Y?(DB;cGt6_u$xi~`&GqR2pMo{AEWvyi? z;mzX5sfx3gxt6ttrG_;flpwhLLRK=};z+H?EG~)9&%VW8kXV$OS5kb7G5+Q6|NsBr zVoA%-NvYz~)z#It^T|(FNXaa+QgG8`xy6!~n45Zwy)3n;I5R)*7H3X=dVG3OVnN0& z!Iadr#L}FS_>6-1+|-hy%w$cbB5_cBN`S&y;1;`kYF=tlV$LlNkS0*v7BMj}Fx+Cz z%}&WIy2V^vQgn+gIX^EgGyN7D#6h=Mee+XNb8c}YB_?Ml<>#e>4ZOwTl30>hBnH-8 zP?VWha*L}VCmtG&w**5{^NRC}obnTkQhdNsdyCUKI1IwnWV^+joR(fB1JVq32*^WG zJWvnCgL#}KNl??GxRQ&@plopBgAjZm-8`Ta3zaJdCCp-wp$tqMj6#fZj3SH*i~`Jj zjC_nN|JYdg82SFQF&AlpLW1cQKiF5!8L7$H1^FQ7^9N-nX9N`Grx&Fb7ds__a*GI> zP)J5<E-3f#x+WIoR0fyi7ZhaXrK6--Rt5$J9$1>4#K6Fi&QQw;O1HI4HB4Ec^a{?- zj0`1ADa<J>H4ItISu81xDXb}My{xs&HH-^bYnT^;(i>X|J1AA9aHO!OaDpLMFMA4i zFJmnWSQQIM6<9w<3QsmuQ6EUOgfoS?gbQR}Gh+%bIDhb^u%<Ak@b_|n)iVY&XbSj+ z6bXVtkE1LR6lKN5>Cgg*;TCglW}YVNE%w~}yv&mPqFbD~nR)RksW~Nyw>SzCOEOdQ zk|9YiDZe<i2;|#aJn2QLsd@30sX00M<+s?tJVP*DX<?+vev3OLwK%ybv!Dc&Sa}Md zAsnBSSaeGinHQgynU`6dk(zRgFDbDkIU^n>d5fhaGdJ~?KtWM{a%ypLW?p)HaY1To z$}M(C8ZXvlyv14qikn-kB}Ivuc~w%?sRjAT8CC{HDYd%QFkf2fnWWXe1QocN97Tqp z0ANqaECOYMTO6PQ4wUYSOhDqC#ffF9@i~dbCAYZ1%#!?q_-u%G1;GO4shQ~+CB^ai zc{!Du+~A~KBn!%DdLTj?L>PliVoNM2NX<(r(gBI+f<=l`i^@`qinKu@a-hV>TBO0i zz!1glTbZ1glLN|T#kaWgOG^q$OX5MUixN%+rMLKEkSkzHiY!6;K}jAdJMe+C158P= z8mJy|0mUs73lke79}^QJ2(p35-)tOQAQ~0_Wckm=!^OnI$ic{gEc26tqsWVafgu@G zT7V1$VHO4k22dFdsw;{k7#JAZ8PY&0oT-JQgrS77gQ1zRnbC!znX!|xglPeDJ7YUj z8YtC+<vW<NSW;Ld86YYaGJ;C!1*|np3mLf?U~<w7wahil!3>(Lekhp?<Oq<(!XW>F zf-s$-h9Oq0ma&8(i?M^DhEWn!?=$u?O=K!$31-k_DgsA3<1OZt%v?>TTa0eE*o*Vi zN^%n`Rx;gU(lfZln5oHli?ISJD7kEMGK))+^K%RAf*BYXK7%}5rJh`rUtAoYnpaYk zUr-rel3$dZp=XnmpPZOeY^R6Nf#F(EH3oLA61YgHVThFixpo26wH=HgH%l@sWQ4dF z><&$)TdZIkZZYTNrx*Ewk_yu;&Wd=b({C|Wg4Gp)O2VQb1_p)@P?7+-7UEyFTdYNi zdFiRQm~&F|ia?UL*h;|V@-5cljKqS}Tg=6!xtc6RfgpY1pft!@0V+&E4T0ifu;H4_ z5ZN%860mKzgs}t!$OXlq<_iO(6r&2G7-LZ!$jOX;ni8POtt2BSGYMSV6oC{)u_dOY zWG0t@TN>bIz%9XoO0a+9b3l13FEcN_NC0H5Du~bm<sE4Z6?*W*4#`lUe0IwSk1{=2 zsiIeuiB!fEf%B#oK8=V>2rBi9z!@EsCm<m!40gI+UY=fBX<jm@tV_(f#gvzKiwnUo z@&Wk-6e{4bMIt~^c#Fd(H$SB`C)Ey=T8lw7HV-ogBM&1N8xJFgr~o4mBM%dY7N}|i E03o{L1^@s6 literal 0 HcmV?d00001 diff --git a/utils/utils.py b/utils/utils.py index 1b7e44f..96ed223 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -14,7 +14,7 @@ def load_loggers(cfg): log_path = cfg.General.log_path Path(log_path).mkdir(exist_ok=True, parents=True) - log_name = Path(cfg.config).parent + log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' version_name = Path(cfg.config).name[:-5] cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}' print(f'---->Log dir: {cfg.log_path}') @@ -31,8 +31,10 @@ def load_loggers(cfg): #---->load Callback -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar +from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme from pytorch_lightning.callbacks.early_stopping import EarlyStopping + def load_callbacks(cfg): Mycallbacks = [] @@ -47,7 +49,21 @@ def load_callbacks(cfg): verbose=True, mode='min' ) + Mycallbacks.append(early_stop_callback) + progress_bar = RichProgressBar( + theme=RichProgressBarTheme( + description='green_yellow', + progress_bar='green1', + progress_bar_finished='green1', + batch_progress='green_yellow', + time='grey82', + processing_speed='grey82', + metrics='grey82' + + ) + ) + Mycallbacks.append(progress_bar) if cfg.General.server == 'train' : Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss', @@ -64,7 +80,7 @@ def load_callbacks(cfg): import torch import torch.nn.functional as F def cross_entropy_torch(x, y): - x_softmax = [F.softmax(x[i]) for i in range(len(x))] - x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(len(y))]) - loss = - torch.sum(x_log) / len(y) + x_softmax = [F.softmax(x[i], dim=0) for i in range(len(x))] + x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])]) + loss = - torch.sum(x_log) / y.shape[0] return loss -- GitLab