diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4cf8dd15619e7c11d325ae0eb80bba874a99f06d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +logs/* \ No newline at end of file diff --git a/Camelyon/TransMIL.yaml b/Camelyon/TransMIL.yaml index 6642b702777078a7c84261101baa1d3844445847..fba514cc00e181fe48a698215a5de1b68034f6c2 100644 --- a/Camelyon/TransMIL.yaml +++ b/Camelyon/TransMIL.yaml @@ -29,9 +29,13 @@ Data: batch_size: 1 num_workers: 8 + + + Model: name: TransMIL n_classes: 2 + backbone: resnet18 Optimizer: diff --git a/DeepGraft/TransMIL.yaml b/DeepGraft/TransMIL.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7945828389d87697cff095a2fcc4c5a062437f90 --- /dev/null +++ b/DeepGraft/TransMIL.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 6 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_dino.yaml b/DeepGraft/TransMIL_dino.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffe987c913ff9075730d7b16b8e9dba16d5d4978 --- /dev/null +++ b/DeepGraft/TransMIL_dino.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 1000 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: dino + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_all.yaml b/DeepGraft/TransMIL_resnet18_all.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8fa5818981b31c1a6c36f34b42e055f2198681fc --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_all.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_all.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 6 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 
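+ # lookahead_radam combines RAdam with the Lookahead wrapper (see MyOptimizer/lookahead.py)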
+ opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_no_other.yaml b/DeepGraft/TransMIL_resnet18_no_other.yaml new file mode 100644 index 0000000000000000000000000000000000000000..95a9bd64692f5ee12dd822e04fea889b80717457 --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_no_other.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [4] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 5 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_no_viral.yaml b/DeepGraft/TransMIL_resnet18_no_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..155b676a24e0541f42f8fee12af2d987217c0525 --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_no_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 4 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7d9bf0694a227f987d1d2fbf2f0facb53c248d5 --- /dev/null +++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 2 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet50.yaml b/DeepGraft/TransMIL_resnet50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d76b2cf618dc0f3f249f9aac787eb471793bd49f --- /dev/null +++ 
b/DeepGraft/TransMIL_resnet50.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 32 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: resnet50 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet50_all.yaml b/DeepGraft/TransMIL_resnet50_all.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eba3a4fa20870ad4fc2b173ccb4e60086ddb3ac5 --- /dev/null +++ b/DeepGraft/TransMIL_resnet50_all.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 1000 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_all.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 6 + backbone: resnet50 + + +Optimizer: + opt: lookahead_radam + lr: 0.0001 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_simple.yaml b/DeepGraft/TransMIL_simple.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4501f2c06242bb6cb951fef458eab000daac0d75 --- /dev/null +++ b/DeepGraft/TransMIL_simple.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc b/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b39de626d6dab873f33584529f0c3a4bef6bf961 Binary files /dev/null and b/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc differ diff --git a/MyLoss/__pycache__/__init__.cpython-39.pyc b/MyLoss/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..119cc4d4a02cb0c100c171e5e23385ccb8309724 Binary files /dev/null and b/MyLoss/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/MyLoss/__pycache__/boundary_loss.cpython-39.pyc b/MyLoss/__pycache__/boundary_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34040291007bdff88adfcea8434647cf1d47f3a2 Binary files /dev/null and b/MyLoss/__pycache__/boundary_loss.cpython-39.pyc differ diff --git a/MyLoss/__pycache__/dice_loss.cpython-39.pyc b/MyLoss/__pycache__/dice_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23837a13153c252543d3e48776b1fbdea5eff0a2 Binary files /dev/null and b/MyLoss/__pycache__/dice_loss.cpython-39.pyc differ diff --git a/MyLoss/__pycache__/focal_loss.cpython-39.pyc b/MyLoss/__pycache__/focal_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2734f4edcae0885494fb2002e9678f127ccc8d Binary files /dev/null and b/MyLoss/__pycache__/focal_loss.cpython-39.pyc differ diff --git a/MyLoss/__pycache__/hausdorff.cpython-39.pyc b/MyLoss/__pycache__/hausdorff.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4472c5317ce1e97640f7705acdd550d6ff75b161 Binary files /dev/null and b/MyLoss/__pycache__/hausdorff.cpython-39.pyc differ diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed7437016bbcadc48f9bb0a97401529dc574c9b2 Binary files /dev/null and b/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ diff --git a/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc b/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b50bae1fae55772838d508428f39b9eb9d4ca90b Binary files /dev/null and b/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc differ diff --git a/MyLoss/loss_factory.py b/MyLoss/loss_factory.py index 2394abe78706f535cf7e20da871031386b72005a..1dffa6182ef40b1f46436b8a1eadb0a8906c17ff 100755 --- a/MyLoss/loss_factory.py +++ b/MyLoss/loss_factory.py @@ -34,8 +34,6 @@ def create_loss(args, w1=1.0, w2=0.5): loss = L.BinaryDiceLoss() elif conf_loss == "dice_log": loss = L.BinaryDiceLogLoss() - elif conf_loss == "dice_log": - loss = L.BinaryDiceLogLoss() elif conf_loss == "bce+lovasz": loss = L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss(), w1, w2) elif conf_loss == "lovasz": @@ -62,6 +60,7 @@ def make_parse(): if __name__ == '__main__': args = make_parse() myloss = create_loss(args) + print(myloss) data = torch.randn(2, 3) label = torch.empty(2, dtype=torch.long).random_(3) loss = myloss(data, label) \ No newline at end of file diff --git a/MyOptimizer/__pycache__/__init__.cpython-39.pyc b/MyOptimizer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..058b9be16d0a0efced0d6e4650eb2c5eb3a2450b Binary files /dev/null and b/MyOptimizer/__pycache__/__init__.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/adafactor.cpython-39.pyc b/MyOptimizer/__pycache__/adafactor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5594d367f0f462a2976347a368a45da43954538c Binary files /dev/null and b/MyOptimizer/__pycache__/adafactor.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/adahessian.cpython-39.pyc b/MyOptimizer/__pycache__/adahessian.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc22c127ae8d56a626d590c23499817208854192 Binary files /dev/null and b/MyOptimizer/__pycache__/adahessian.cpython-39.pyc differ diff --git 
a/MyOptimizer/__pycache__/adamp.cpython-39.pyc b/MyOptimizer/__pycache__/adamp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ba6bf0c26ff7652d56f9139e6b93a9ef69f58aa Binary files /dev/null and b/MyOptimizer/__pycache__/adamp.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/adamw.cpython-39.pyc b/MyOptimizer/__pycache__/adamw.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae0384996e05f7a7498980d9d36881eb3625d2e Binary files /dev/null and b/MyOptimizer/__pycache__/adamw.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/lookahead.cpython-39.pyc b/MyOptimizer/__pycache__/lookahead.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ae677bc3ab72ad3f7a4591ee097c38fb2755106 Binary files /dev/null and b/MyOptimizer/__pycache__/lookahead.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/nadam.cpython-39.pyc b/MyOptimizer/__pycache__/nadam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bcbfa6a76f3d6e00685eef8c40c0840e26ae884 Binary files /dev/null and b/MyOptimizer/__pycache__/nadam.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/novograd.cpython-39.pyc b/MyOptimizer/__pycache__/novograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9f49e4a9cd08791de1aafb9c381117c17a9d9ee Binary files /dev/null and b/MyOptimizer/__pycache__/novograd.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc b/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3528945920d301895ec14ba380f4b1132b62478e Binary files /dev/null and b/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc b/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6bfac29b2f804c5be7a3f18b97fe6d9401a3cdc Binary files /dev/null and b/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/radam.cpython-39.pyc b/MyOptimizer/__pycache__/radam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11a200d1a53e2fa7f59e54daaae3f30c6043ce52 Binary files /dev/null and b/MyOptimizer/__pycache__/radam.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc b/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70fc2eb4e1c4a2ff4658ffaa59558f13219a5a73 Binary files /dev/null and b/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc differ diff --git a/MyOptimizer/__pycache__/sgdp.cpython-39.pyc b/MyOptimizer/__pycache__/sgdp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b747bafd446e4ab8cc65806fbf09cb1f38430e4 Binary files /dev/null and b/MyOptimizer/__pycache__/sgdp.cpython-39.pyc differ diff --git a/MyOptimizer/lookahead.py b/MyOptimizer/lookahead.py index 6b5b7f38ec8cb6594e3986b66223fa2881daeca3..b8e8b0095e8da7a28b8742df9e0322996941bb74 100755 --- a/MyOptimizer/lookahead.py +++ b/MyOptimizer/lookahead.py @@ -35,7 +35,9 @@ class Lookahead(Optimizer): param_state['slow_buffer'] = torch.empty_like(fast_p.data) param_state['slow_buffer'].copy_(fast_p.data) slow = param_state['slow_buffer'] - slow.add_(group['lookahead_alpha'], fast_p.data - slow) + # slow.add_(group['lookahead_alpha'], fast_p.data - slow) + 
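+ # newer PyTorch deprecates Tensor.add_(Number, Tensor); the alpha keyword form below is the same in-place update: slow += alpha * (fast_p - slow)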
slow.add_(fast_p.data-slow, alpha=group['lookahead_alpha']) + fast_p.data.copy_(slow) def sync_lookahead(self): diff --git a/MyOptimizer/optim_factory.py b/MyOptimizer/optim_factory.py index ce310e3f593680b579369b51d047a681b41ce351..992231aab94e896725de99e636595e3b0ce2ebe7 100755 --- a/MyOptimizer/optim_factory.py +++ b/MyOptimizer/optim_factory.py @@ -75,7 +75,8 @@ def create_optimizer(args, model, filter_bias_and_bn=True): elif opt_lower == 'nadam': optimizer = Nadam(parameters, **opt_args) elif opt_lower == 'radam': - optimizer = RAdam(parameters, **opt_args) + # optimizer = RAdam(parameters, **opt_args) + optimizer = optim.RAdam(parameters, **opt_args) elif opt_lower == 'adamp': optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) elif opt_lower == 'sgdp': diff --git a/datasets/__pycache__/__init__.cpython-39.pyc b/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13006df55083dc070f46953e4505bfef1ac8b198 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-39.pyc differ diff --git a/datasets/__pycache__/camel_data.cpython-39.pyc b/datasets/__pycache__/camel_data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ffc96a2ce0ccb33cabde901455fd1b7c9c44811 Binary files /dev/null and b/datasets/__pycache__/camel_data.cpython-39.pyc differ diff --git a/datasets/__pycache__/camel_dataloader.cpython-39.pyc b/datasets/__pycache__/camel_dataloader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11da2c2e9cc3f2bb97782008855d83b10df237c4 Binary files /dev/null and b/datasets/__pycache__/camel_dataloader.cpython-39.pyc differ diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..147b3fc3c4628d6eed5bebd6ff248d31aaeebb34 Binary files /dev/null and b/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4af0141496c11d6f97255add8746fa0067e55636 Binary files /dev/null and b/datasets/__pycache__/data_interface.cpython-39.pyc differ diff --git a/datasets/camel_dataloader.py b/datasets/camel_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..302cabd491bf81619af8c11692ae560ea81bf410 --- /dev/null +++ b/datasets/camel_dataloader.py @@ -0,0 +1,126 @@ +import pandas as pd + +import numpy as np +import torch +from torch import Tensor +from torch.autograd import Variable +from torch.nn.functional import one_hot +import torch.utils.data as data_utils +from torchvision import datasets, transforms +import pandas as pd +from sklearn.utils import shuffle +from pathlib import Path +from tqdm import tqdm + + +class FeatureBagLoader(data_utils.Dataset): + def __init__(self, data_root,train=True, cache=True): + + bags_path = pd.read_csv(data_root) + + self.train_path = bags_path.iloc[0:int(len(bags_path)*0.8), :] + self.test_path = bags_path.iloc[int(len(bags_path)*0.8):, :] + # self.train_path = shuffle(train_path).reset_index(drop=True) + # self.test_path = shuffle(test_path).reset_index(drop=True) + + home = Path.cwd().parts[1] + self.origin_path = Path(f'/{home}/ylan/RCC_project/rcc_classification/') + # self.target_number = target_number + # self.mean_bag_length = mean_bag_length + # 
self.var_bag_length = var_bag_length + # self.num_bag = num_bag + self.cache = cache + self.train = train + self.n_classes = 2 + + self.features = [] + self.labels = [] + if self.cache: + if train: + with tqdm(total=len(self.train_path)) as pbar: + for t in tqdm(self.train_path.iloc()): + ft, lbl = self.get_bag_feats(t) + # ft = ft.view(-1, 512) + + self.labels.append(lbl) + self.features.append(ft) + pbar.update() + else: + with tqdm(total=len(self.test_path)) as pbar: + for t in tqdm(self.test_path.iloc()): + ft, lbl = self.get_bag_feats(t) + # lbl = Variable(Tensor(lbl)) + # ft = Variable(Tensor(ft)).view(-1, 512) + self.labels.append(lbl) + self.features.append(ft) + pbar.update() + # print(self.get_bag_feats(self.train_path)) + # self.r = np.random.RandomState(seed) + + # self.num_in_train = 60000 + # self.num_in_test = 10000 + + # if self.train: + # self.train_bags_list, self.train_labels_list = self._create_bags() + # else: + # self.test_bags_list, self.test_labels_list = self._create_bags() + + def get_bag_feats(self, csv_file_df): + # if args.dataset == 'TCGA-lung-default': + # feats_csv_path = 'datasets/tcga-dataset/tcga_lung_data_feats/' + csv_file_df.iloc[0].split('/')[1] + '.csv' + # else: + + feats_csv_path = self.origin_path / csv_file_df.iloc[0] + df = pd.read_csv(feats_csv_path) + # feats = shuffle(df).reset_index(drop=True) + # feats = feats.to_numpy() + feats = df.to_numpy() + label = np.zeros(self.n_classes) + if self.n_classes==2: + label[1] = csv_file_df.iloc[1] + else: + if int(csv_file_df.iloc[1])<=(len(label)-1): + label[int(csv_file_df.iloc[1])] = 1 + + return feats, label + + def __len__(self): + if self.train: + return len(self.train_path) + else: + return len(self.test_path) + + def __getitem__(self, index): + + if self.cache: + label = self.labels[index] + feats = self.features[index] + label = Variable(Tensor(label)) + feats = Variable(Tensor(feats)).view(-1, 512) + return feats, label + else: + if self.train: + feats, label = self.get_bag_feats(self.train_path.iloc[index]) + label = Variable(Tensor(label)) + feats = Variable(Tensor(feats)).view(-1, 512) + else: + feats, label = self.get_bag_feats(self.test_path.iloc[index]) + label = Variable(Tensor(label)) + feats = Variable(Tensor(feats)).view(-1, 512) + + return feats, label + +if __name__ == '__main__': + import os + cwd = os.getcwd() + home = cwd.split('/')[1] + data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv' + dataset = FeatureBagLoader(data_root, cache=False) + for i in dataset: + # print(i[1]) + # print(i) + + features, label = i + print(label) + # print(features.shape) + # print(label[0].long()) \ No newline at end of file diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc2ed36c64fc46f0844673514b51cead0052bc6 --- /dev/null +++ b/datasets/custom_dataloader.py @@ -0,0 +1,332 @@ +import h5py +# import helpers +import numpy as np +from pathlib import Path +import torch +# from torch._C import long +from torch.utils import data +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm +# from histoTransforms import RandomHueSaturationValue +import torchvision.transforms as transforms +import torch.nn.functional as F +import csv +from PIL import Image +import cv2 +import pandas as pd +import json + +class HDF5MILDataloader(data.Dataset): + """Represents an abstract HDF5 dataset. For single H5 container! 
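+ Each slide (WSI) is stored as a single HDF5 container of tile images; one file yields one bag.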
+ + Input params: + file_path: Path to the folder containing the dataset (one or multiple HDF5 files). + mode: 'train' or 'test' + load_data: If True, loads all the data immediately into RAM. Use this if + the dataset fits into memory. Otherwise, leave this at False and + the data will load lazily. + data_cache_size: Number of HDF5 files that can be held in the cache at once (default=20). + + """ + def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20): + super().__init__() + + self.data_info = [] + self.data_cache = {} + self.slideLabelDict = {} + self.data_cache_size = data_cache_size + self.mode = mode + self.file_path = file_path + # self.csv_path = csv_path + self.label_path = label_path + self.n_classes = n_classes + self.bag_size = 120 + # self.label_file = label_path + recursive = True + + # read labels and slide_path from csv + + # df = pd.read_csv(self.csv_path) + # labels = df.LABEL + # slides = df.FILENAME + with open(self.label_path, 'r') as f: + self.slideLabelDict = json.load(f)[mode] + + self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict} + + + # if Path(slides[0]).suffix: + # slides = list(map(lambda x: Path(x).stem, slides)) + + # print(labels) + # print(slides) + # self.slideLabelDict = dict(zip(slides, labels)) + # print(self.slideLabelDict) + + #check if files in slideLabelDict, only take files that are available. + + files_in_path = list(Path(self.file_path).rglob('*.hdf5')) + files_in_path = [x.stem for x in files_in_path] + # print(len(files_in_path)) + # print(files_in_path) + # print(list(self.slideLabelDict.keys())) + # for x in list(self.slideLabelDict.keys()): + # if x in files_in_path: + # path = Path(self.file_path) / (x + '.hdf5') + # print(path) + + self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path] + + print(len(self.files)) + # self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys()))) + + for h5dataset_fp in tqdm(self.files): + # print(h5dataset_fp) + self._add_data_infos(str(h5dataset_fp.resolve()), load_data) + + # print(self.data_info) + self.resize_transforms = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize(256), + ]) + + self.img_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(p=1), + transforms.RandomVerticalFlip(p=1), + # histoTransforms.AutoRandomRotation(), + transforms.Lambda(lambda a: np.array(a)), + ]) + self.hsv_transforms = transforms.Compose([ + RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)), + transforms.ToTensor() + ]) + + # self._add_data_infos(load_data) + + + def __getitem__(self, index): + # get data + batch, label, name = self.get_data(index) + out_batch = [] + + if self.mode == 'train': + # print(img) + # print(img.shape) + for img in batch: + img = self.img_transforms(img) + img = self.hsv_transforms(img) + out_batch.append(img) + + else: + for img in batch: + img = transforms.functional.to_tensor(img) + out_batch.append(img) + if len(out_batch) == 0: + # print(name) + out_batch = torch.randn(100,3,256,256) + else: out_batch = torch.stack(out_batch) + out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch + + label = torch.as_tensor(label) + label = torch.nn.functional.one_hot(label, num_classes=self.n_classes) + return out_batch, label, name + + def __len__(self): + return len(self.data_info) + + def _add_data_infos(self, 
file_path, load_data): + wsi_name = Path(file_path).stem + if wsi_name in self.slideLabelDict: + label = self.slideLabelDict[wsi_name] + wsi_batch = [] + # with h5py.File(file_path, 'r') as h5_file: + # numKeys = len(h5_file.keys()) + # sample = list(h5_file.keys())[0] + # shape = (numKeys,) + h5_file[sample][:].shape + # for tile in h5_file.keys(): + # img = h5_file[tile][:] + + # print(img) + # if type == 'images': + # t = 'data' + # else: + # t = 'label' + idx = -1 + # if load_data: + # for tile in h5_file.keys(): + # img = h5_file[tile][:] + # img = img.astype(np.uint8) + # img = self.resize_transforms(img) + # wsi_batch.append(img) + # idx = self._add_to_cache(wsi_batch, file_path) + # wsi_batch.append(img) + # self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx}) + self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx}) + + def _load_data(self, file_path): + """Load data to the cache given the file + path and update the cache index in the + data_info structure. + """ + with h5py.File(file_path, 'r') as h5_file: + wsi_batch = [] + for tile in h5_file.keys(): + img = h5_file[tile][:] + img = img.astype(np.uint8) + img = self.resize_transforms(img) + wsi_batch.append(img) + idx = self._add_to_cache(wsi_batch, file_path) + file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) + self.data_info[file_idx + idx]['cache_idx'] = idx + + # for type in ['images', 'labels']: + # for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()): + # img = h5_file[data_path][:] + # idx = self._add_to_cache(img, data_path) + # file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path) + # self.data_info[file_idx + idx]['cache_idx'] = idx + # for gname, group in h5_file.items(): + # for dname, ds in group.items(): + # # add data to the data cache and retrieve + # # the cache index + # idx = self._add_to_cache(ds.value, file_path) + + # # find the beginning index of the hdf5 file we are looking for + # file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path) + + # # the data info should have the same index since we loaded it in the same way + # self.data_info[file_idx + idx]['cache_idx'] = idx + + # remove an element from data cache if size was exceeded + if len(self.data_cache) > self.data_cache_size: + # remove one item from the cache at random + removal_keys = list(self.data_cache) + removal_keys.remove(file_path) + self.data_cache.pop(removal_keys[0]) + # remove invalid cache_idx + # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + + def _add_to_cache(self, data, data_path): + """Adds data to the cache and returns its index. There is one cache + list for every file_path, containing all datasets in that file. + """ + if data_path not in self.data_cache: + self.data_cache[data_path] = [data] + else: + self.data_cache[data_path].append(data) + return len(self.data_cache[data_path]) - 1 + + # def get_data_infos(self, type): + # """Get data infos belonging to a certain type of data. 
+ # """ + # data_info_type = [di for di in self.data_info if di['type'] == type] + # return data_info_type + + def get_name(self, i): + # name = self.get_data_infos(type)[i]['name'] + name = self.data_info[i]['name'] + return name + + def get_data(self, i): + """Call this function anytime you want to access a chunk of data from the + dataset. This will make sure that the data is loaded in case it is + not part of the data cache. + i = index + """ + # fp = self.get_data_infos(type)[i]['data_path'] + fp = self.data_info[i]['data_path'] + if fp not in self.data_cache: + self._load_data(fp) + + # get new cache_idx assigned by _load_data_info + # cache_idx = self.get_data_infos(type)[i]['cache_idx'] + cache_idx = self.data_info[i]['cache_idx'] + label = self.data_info[i]['label'] + name = self.data_info[i]['name'] + # print(self.data_cache[fp][cache_idx]) + return self.data_cache[fp][cache_idx], label, name + + +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + + + +if __name__ == '__main__': + from pathlib import Path + import os + + home = Path.cwd().parts[1] + train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' + data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json' + output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/' + # os.makedirs(output_path, exist_ok=True) + + + dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6) + data = DataLoader(dataset, batch_size=1) + + # print(len(dataset)) + x = 0 + c = 0 + for item in data: + if c >=10: + break + bag, label, name = item + print(bag) + # # print(bag.shape) + # if bag.shape[1] == 1: + # print(name) + # print(bag.shape) + # print(bag.shape) + # print(name) + # out_dir = Path(output_path) / name + # os.makedirs(out_dir, exist_ok=True) + + # # print(item[2]) + # # print(len(item)) + # # print(item[1]) + # # print(data.shape) + # # data = data.squeeze() + # bag = item[0] + # bag = bag.squeeze() + # for i in range(bag.shape[0]): + # img = bag[i, :, :, :] + # img = img.squeeze() + # img = img*255 + # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + # img = Image.fromarray(img) + # img = img.convert('RGB') + # img.save(f'{out_dir}/{i}.png') + c += 1 + # else: break + # print(data.shape) + # print(label) \ No newline at end of file diff --git a/datasets/data_interface.py b/datasets/data_interface.py index 3952e5bae7d77d2f67f6e417906e7accf9c00517..12a0f8c450b4945658d3566cc5c157116bccab66 100644 --- a/datasets/data_interface.py 
+++ b/datasets/data_interface.py @@ -1,9 +1,13 @@ import inspect # inspect the parameters and source code of Python classes, modules and functions import importlib # In order to dynamically import the library +from typing import Optional import pytorch_lightning as pl from torch.utils.data import random_split, DataLoader from torchvision.datasets import MNIST from torchvision import transforms +from .camel_dataloader import FeatureBagLoader +from .custom_dataloader import HDF5MILDataloader +from pathlib import Path class DataInterface(pl.LightningDataModule): @@ -24,6 +28,8 @@ class DataInterface(pl.LightningDataModule): self.dataset_name = dataset_name self.kwargs = kwargs self.load_data_module() + home = Path.cwd().parts[1] + self.data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv' @@ -46,14 +52,23 @@ class DataInterface(pl.LightningDataModule): """ # Assign train/val datasets for use in dataloaders if stage == 'fit' or stage is None: - self.train_dataset = self.instancialize(state='train') - self.val_dataset = self.instancialize(state='val') + dataset = FeatureBagLoader(data_root = self.data_root, + train=True) + a = int(len(dataset)* 0.8) + b = int(len(dataset) - a) + print(a) + print(b) + self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) + # self.train_dataset = self.instancialize(state='train') + # self.val_dataset = self.instancialize(state='val') # Assign test dataset for use in dataloader(s) if stage == 'test' or stage is None: # self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) - self.test_dataset = self.instancialize(state='test') + self.test_dataset = FeatureBagLoader(data_root = self.data_root, + train=False) + # self.test_dataset = self.instancialize(state='test') def train_dataloader(self): @@ -87,4 +102,62 @@ class DataInterface(pl.LightningDataModule): if arg in inkeys: args1[arg] = self.kwargs[arg] args1.update(other_args) - return self.data_module(**args1) \ No newline at end of file + return self.data_module(**args1) + +class MILDataModule(pl.LightningDataModule): + + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, **kwargs): + super().__init__() + self.data_root = data_root + self.label_path = label_path + self.batch_size = batch_size + self.num_workers = num_workers + self.image_size = 384 + self.n_classes = n_classes + self.target_number = 9 + self.mean_bag_length = 10 + self.var_bag_length = 2 + self.num_bags_train = 200 + self.num_bags_test = 50 + self.seed = 1 + + self.cache = cache + + + def setup(self, stage: Optional[str] = None) -> None: + # if self.n_classes == 2: + # if stage in (None, 'fit'): + # dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes) + # a = int(len(dataset)* 0.8) + # b = int(len(dataset) - a) + # self.train_data, self.valid_data = random_split(dataset, [a, b]) + + # if stage in (None, 'test'): + # self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes) + # else: + home = Path.cwd().parts[1] + # self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json' + # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv' + # test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv' + + + if stage in (None, 'fit'): + dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) + # print(len(dataset)) + a = int(len(dataset)* 0.8) + b = int(len(dataset) - a) + self.train_data, self.valid_data = 
random_split(dataset, [a, b]) + + if stage in (None, 'test'): + self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes) + + return super().setup(stage=stage) + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + + def val_dataloader(self) -> DataLoader: + return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers) \ No newline at end of file diff --git a/models/TransMIL.py b/models/TransMIL.py index 3cb4e52c6ce1802bcddcde727aa80c260efe2e76..ce40a26b37b1886bf5698ee4ab8ecf07c1e4e2c8 100755 --- a/models/TransMIL.py +++ b/models/TransMIL.py @@ -47,7 +47,8 @@ class TransMIL(nn.Module): def __init__(self, n_classes): super(TransMIL, self).__init__() self.pos_layer = PPEG(dim=512) - self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) + self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()) + # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) self.n_classes = n_classes self.layer1 = TransLayer(dim=512) @@ -56,11 +57,10 @@ class TransMIL(nn.Module): self._fc2 = nn.Linear(512, self.n_classes) - def forward(self, **kwargs): + def forward(self, **kwargs): #, **kwargs h = kwargs['data'].float() #[B, n, 1024] - - h = self._fc1(h) #[B, n, 512] + # h = self._fc1(h) #[B, n, 512] #---->pad H = h.shape[1] @@ -86,15 +86,19 @@ class TransMIL(nn.Module): h = self.norm(h)[:,0] #---->predict - logits = self._fc2(h) #[B, n_classes] + logits = self._fc2(torch.sigmoid(h)) #[B, n_classes] Y_hat = torch.argmax(logits, dim=1) Y_prob = F.softmax(logits, dim = 1) results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} return results_dict if __name__ == "__main__": - data = torch.randn((1, 6000, 1024)).cuda() + data = torch.randn((1, 6000, 512)).cuda() model = TransMIL(n_classes=2).cuda() print(model.eval()) results_dict = model(data = data) print(results_dict) + logits = results_dict['logits'] + Y_prob = results_dict['Y_prob'] + Y_hat = results_dict['Y_hat'] + # print(F.sigmoid(logits)) diff --git a/models/__init__.py b/models/__init__.py index 497cee19810dc56d0933816137b554d7b3c760cc..73aad9d74d278565976a1dd2c63c69dc7a0997ad 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1 @@ -from .model_interface import ModelInterface \ No newline at end of file +from .model_interface import ModelInterface diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1ddff6d6f3cbadfd7f1c0c4686a7806f5896e8 Binary files /dev/null and b/models/__pycache__/TransMIL.cpython-39.pyc differ diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45428bd7fc4c40dc9fb1116df4a96aea2fb568aa Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e81c0ccdd33e6fe68a5ffe1b602990134944c80 Binary files /dev/null and 
b/models/__pycache__/model_interface.cpython-39.pyc differ diff --git a/models/__pycache__/vision_transformer.cpython-39.pyc b/models/__pycache__/vision_transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e278bdeba138e438387aa38b7a83ca7bb7819a7e Binary files /dev/null and b/models/__pycache__/vision_transformer.cpython-39.pyc differ diff --git a/models/model_interface.py b/models/model_interface.py index c7bb72323c99a57672eff72f7e660ba90c269594..1b0f6e19f8429b764e6fd45b96bf821981d0731e 100755 --- a/models/model_interface.py +++ b/models/model_interface.py @@ -4,6 +4,9 @@ import inspect import importlib import random import pandas as pd +import seaborn as sns +from pathlib import Path +from matplotlib import pyplot as plt #----> from MyOptimizer import create_optimizer @@ -18,9 +21,11 @@ import torchmetrics #----> import pytorch_lightning as pl +from .vision_transformer import vit_small +from torchvision import models +from torchvision.models import resnet - -class ModelInterface(pl.LightningModule): +class ModelInterface(pl.LightningModule): #---->init def __init__(self, model, loss, optimizer, **kargs): @@ -37,11 +42,11 @@ class ModelInterface(pl.LightningModule): #---->Metrics if self.n_classes > 2: - self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'macro') + self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted') metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes, average='micro'), torchmetrics.CohenKappa(num_classes = self.n_classes), - torchmetrics.F1(num_classes = self.n_classes, + torchmetrics.F1Score(num_classes = self.n_classes, average = 'macro'), torchmetrics.Recall(average = 'macro', num_classes = self.n_classes), @@ -49,17 +54,19 @@ class ModelInterface(pl.LightningModule): num_classes = self.n_classes), torchmetrics.Specificity(average = 'macro', num_classes = self.n_classes)]) + else : - self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'macro') + self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted') metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2, average = 'micro'), torchmetrics.CohenKappa(num_classes = 2), - torchmetrics.F1(num_classes = 2, + torchmetrics.F1Score(num_classes = 2, average = 'macro'), torchmetrics.Recall(average = 'macro', num_classes = 2), torchmetrics.Precision(average = 'macro', num_classes = 2)]) + self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes) self.valid_metrics = metrics.clone(prefix = 'val_') self.test_metrics = metrics.clone(prefix = 'test_') @@ -67,18 +74,103 @@ class ModelInterface(pl.LightningModule): self.shuffle = kargs['data'].data_shuffle self.count = 0 + self.out_features = 512 + if kargs['backbone'] == 'dino': + #---> dino feature extractor + arch = 'vit_small' + patch_size = 16 + n_last_blocks = 4 + # num_labels = 1000 + avgpool_patchtokens = False + home = Path.cwd().parts[1] + + weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth' + model = vit_small(patch_size, num_classes=0) + # model.eval() + # set_parameter_requires_grad(model, feature_extracting) + for param in model.parameters(): + param.requires_grad = False + # print(model.embed_dim) + # embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens)) + # model.eval() + # print(embed_dim) + linear = nn.Linear(model.embed_dim, self.out_features) + linear.weight.data.normal_(mean=0.0, std=0.01) + linear.bias.data.zero_() + + 
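+ # feature extractor: frozen DINO ViT backbone followed by a trainable linear projection to out_features (512)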
self.model_ft = nn.Sequential( + model, + linear, + ) + elif kargs['backbone'] == 'resnet18': + resnet18 = models.resnet18(pretrained=True) + modules = list(resnet18.children())[:-1] + # model_ft.fc = nn.Linear(512, out_features) + + res18 = nn.Sequential( + *modules, + + ) + for param in res18.parameters(): + param.requires_grad = False + self.model_ft = nn.Sequential( + res18, + nn.AdaptiveAvgPool2d(1), + View((-1, 512)), + nn.Linear(512, self.out_features), + nn.ReLU(), + ) + elif kargs['backbone'] == 'resnet50': + + resnet50 = models.resnet50(pretrained=True) + # model_ft.fc = nn.Linear(1024, out_features) + modules = list(resnet50.children())[:-3] + res50 = nn.Sequential( + *modules, + ) + for param in res50.parameters(): + param.requires_grad = False + self.model_ft = nn.Sequential( + res50, + nn.AdaptiveAvgPool2d(1), + View((-1, 1024)), + nn.Linear(1024, self.out_features), + nn.ReLU() + ) + elif kargs['backbone'] == 'simple': #mil-ab attention + feature_extracting = False + self.model_ft = nn.Sequential( + nn.Conv2d(3, 20, kernel_size=5), + nn.ReLU(), + nn.MaxPool2d(2, stride=2), + nn.Conv2d(20, 50, kernel_size=5), + nn.ReLU(), + nn.MaxPool2d(2, stride=2), + View((-1, 1024)), + nn.Linear(1024, self.out_features), + nn.ReLU(), + ) #---->remove v_num - def get_progress_bar_dict(self): - # don't show the version number - items = super().get_progress_bar_dict() - items.pop("v_num", None) - return items + # def get_progress_bar_dict(self): + # # don't show the version number + # items = super().get_progress_bar_dict() + # items.pop("v_num", None) + # return items def training_step(self, batch, batch_idx): #---->inference - data, label = batch - results_dict = self.model(data=data, label=label) + data, label, _ = batch + label = label.float() + data = data.squeeze(0) + # print(data.shape) + features = self.model_ft(data) + + features = features.unsqueeze(0) + # print(features.shape) + # features = features.squeeze() + results_dict = self.model(data=features) + # results_dict = self.model(data=data, label=label) logits = results_dict['logits'] Y_prob = results_dict['Y_prob'] Y_hat = results_dict['Y_hat'] @@ -87,8 +179,13 @@ class ModelInterface(pl.LightningModule): loss = self.loss(logits, label) #---->acc log + # print(label) Y_hat = int(Y_hat) - Y = int(label) + # if self.n_classes == 2: + # Y = int(label[0][1]) + # else: + Y = torch.argmax(label) + # Y = int(label[0]) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat == Y) @@ -106,19 +203,28 @@ class ModelInterface(pl.LightningModule): self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] def validation_step(self, batch, batch_idx): - data, label = batch - results_dict = self.model(data=data, label=label) + + data, label, _ = batch + + label = label.float() + data = data.squeeze(0) + features = self.model_ft(data) + features = features.unsqueeze(0) + + results_dict = self.model(data=features) logits = results_dict['logits'] Y_prob = results_dict['Y_prob'] Y_hat = results_dict['Y_hat'] #---->acc log - Y = int(label) + # Y = int(label[0][1]) + Y = torch.argmax(label) + self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y} def validation_epoch_end(self, val_step_outputs): @@ -126,13 +232,26 @@ class ModelInterface(pl.LightningModule): probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0) max_probs = 
torch.stack([x['Y_hat'] for x in val_step_outputs]) target = torch.stack([x['label'] for x in val_step_outputs], dim = 0) - #----> + # logits = logits.long() + # target = target.squeeze().long() + # logits = logits.squeeze(0) self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True) self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) - self.log_dict(self.valid_metrics(max_probs.squeeze() , target.squeeze()), + + # print(max_probs.squeeze(0).shape) + # print(target.shape) + self.log_dict(self.valid_metrics(max_probs.squeeze() , target), on_epoch = True, logger = True) + #----> log confusion matrix + confmat = self.confusion_matrix(max_probs.squeeze(), target) + df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) + plt.figure() + fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() + plt.close(fig_) + self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) + #---->acc log for c in range(self.n_classes): count = self.data[c]["count"] @@ -156,18 +275,24 @@ class ModelInterface(pl.LightningModule): return [optimizer] def test_step(self, batch, batch_idx): - data, label = batch - results_dict = self.model(data=data, label=label) + + data, label, _ = batch + label = label.float() + data = data.squeeze(0) + features = self.model_ft(data) + features = features.unsqueeze(0) + + results_dict = self.model(data=features, label=label) logits = results_dict['logits'] Y_prob = results_dict['Y_prob'] Y_hat = results_dict['Y_hat'] #---->acc log - Y = int(label) + Y = torch.argmax(label) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y} def test_epoch_end(self, output_results): probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0) @@ -176,12 +301,20 @@ class ModelInterface(pl.LightningModule): #----> auc = self.AUROC(probs, target.squeeze()) - metrics = self.test_metrics(max_probs.squeeze() , target.squeeze()) - metrics['auc'] = auc + metrics = self.test_metrics(max_probs.squeeze() , target) + + + # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1)) + metrics['test_auc'] = auc + + # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True) + + # print(max_probs.squeeze(0).shape) + # print(target.shape) + # self.log_dict(metrics, logger = True) for keys, values in metrics.items(): print(f'{keys} = {values}') metrics[keys] = values.cpu().numpy() - print() #---->acc log for c in range(self.n_classes): count = self.data[c]["count"] @@ -192,6 +325,16 @@ class ModelInterface(pl.LightningModule): acc = float(correct) / count print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count)) self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] + + confmat = self.confusion_matrix(max_probs.squeeze(), target) + df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) + plt.figure() + fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() + # plt.close(fig_) + # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) + plt.savefig(f'{self.log_path}/cm_test') + plt.close(fig_) + #----> result = pd.DataFrame([metrics]) result.to_csv(self.log_path / 'result.csv') @@ -226,4 +369,18 @@ class 
ModelInterface(pl.LightningModule): if arg in inkeys: args1[arg] = getattr(self.hparams.model, arg) args1.update(other_args) - return Model(**args1) \ No newline at end of file + return Model(**args1) + +class View(nn.Module): + def __init__(self, shape): + super().__init__() + self.shape = shape + + def forward(self, input): + ''' + Reshapes the input according to the shape saved in the view data structure. + ''' + # batch_size = input.size(0) + # shape = (batch_size, *self.shape) + out = input.view(*self.shape) + return out \ No newline at end of file diff --git a/models/vision_transformer.py b/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ebffe7a868b547806ff19e30290729f9688cc0fa --- /dev/null +++ b/models/vision_transformer.py @@ -0,0 +1,330 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +import warnings +from functools import partial + +import torch +import torch.nn as nn + +# from utils import trunc_normal_ + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def 
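Note on the new module (a sketch for reviewers, not part of the patch): `vit_small(patch_size=16)` maps each tile to a single 384-dim [CLS] embedding via `forward()`, which is how a `backbone: dino` feature extractor would feed per-tile features into TransMIL. The snippet below uses random weights; how the repo loads pretrained DINO checkpoints is not shown in this diff, so that part is left out.

import torch
from models.vision_transformer import vit_small, DINOHead

model = vit_small(patch_size=16)   # embed_dim=384; random init here
model.eval()

bag = torch.randn(8, 3, 256, 256)  # a bag of 8 tiles, 256x256 as in the 256_256um dataset
with torch.no_grad():
    feats = model(bag)             # forward() returns only the [CLS] token per tile
print(feats.shape)                 # torch.Size([8, 384])

# The projection head used during DINO pre-training (not needed for MIL inference):
head = DINOHead(in_dim=384, out_dim=65536)
print(head(feats).shape)           # torch.Size([8, 65536])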
diff --git a/train.py b/train.py
index c4b4d7d7c27fdeacbfcd9b56609557f5beed8372..182e3c0f4f0df1096166441731e39794f8599766 100644
--- a/train.py
+++ b/train.py
@@ -3,8 +3,9 @@ from pathlib import Path
 import numpy as np
 import glob
 
-from datasets import DataInterface
-from models import ModelInterface
+from datasets.data_interface import DataInterface, MILDataModule
+from models.model_interface import ModelInterface
+import models.vision_transformer as vits
 from utils.utils import *
 
 # pytorch_lightning
@@ -15,8 +16,8 @@ from pytorch_lightning import Trainer
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
-    parser.add_argument('--config', default='Camelyon/TransMIL.yaml',type=str)
-    parser.add_argument('--gpus', default = [2])
+    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    # parser.add_argument('--gpus', default = [2])
     parser.add_argument('--fold', default = 0)
     args = parser.parse_args()
     return args
@@ -34,20 +35,31 @@ def main(cfg):
     cfg.callbacks = load_callbacks(cfg)
 
     #---->Define Data
-    DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
-                'train_num_workers': cfg.Data.train_dataloader.num_workers,
-                'test_batch_size': cfg.Data.test_dataloader.batch_size,
-                'test_num_workers': cfg.Data.test_dataloader.num_workers,
-                'dataset_name': cfg.Data.dataset_name,
-                'dataset_cfg': cfg.Data,}
-    dm = DataInterface(**DataInterface_dict)
+    # DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
+    #             'train_num_workers': cfg.Data.train_dataloader.num_workers,
+    #             'test_batch_size': cfg.Data.test_dataloader.batch_size,
+    #             'test_num_workers': cfg.Data.test_dataloader.num_workers,
+    #             'dataset_name': cfg.Data.dataset_name,
+    #             'dataset_cfg': cfg.Data,}
+    # dm = DataInterface(**DataInterface_dict)
+    home = Path.cwd().parts[1]
+    DataInterface_dict = {
+        'data_root': cfg.Data.data_dir,
+        'label_path': cfg.Data.label_file,
+        'batch_size': cfg.Data.train_dataloader.batch_size,
+        'num_workers': cfg.Data.train_dataloader.num_workers,
+        'n_classes': cfg.Model.n_classes,
+    }
+    dm = MILDataModule(**DataInterface_dict)
+
     #---->Define Model
     ModelInterface_dict = {'model': cfg.Model,
                             'loss': cfg.Loss,
                             'optimizer': cfg.Optimizer,
                             'data': cfg.Data,
-                            'log': cfg.log_path
+                            'log': cfg.log_path,
+                            'backbone': cfg.Model.backbone,
                             }
     model = ModelInterface(**ModelInterface_dict)
 
@@ -57,12 +69,18 @@ def main(cfg):
         logger=cfg.load_loggers,
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
+        min_epochs = 200,
         gpus=cfg.General.gpus,
-        amp_level=cfg.General.amp_level,
+        # gpus = [4],
+        # strategy='ddp',
+        amp_backend='native',
+        # amp_level=cfg.General.amp_level,
         precision=cfg.General.precision,
         accumulate_grad_batches=cfg.General.grad_acc,
-        deterministic=True,
-        check_val_every_n_epoch=1,
+        # fast_dev_run = True,
+
+        # deterministic=True,
+        check_val_every_n_epoch=10,
     )
 
     #---->train or test
@@ -83,7 +101,7 @@ if __name__ == '__main__':
 
     #---->update
     cfg.config = args.config
-    cfg.General.gpus = args.gpus
+    # cfg.General.gpus = args.gpus
     cfg.General.server = args.stage
     cfg.Data.fold = args.fold
 
diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dcd009c36ba1c22656489be43171004bc84df85
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1eeb958497b23b1f9e9934764083531720a66d0
Binary files /dev/null and b/utils/__pycache__/utils.cpython-39.pyc differ
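The implementation of MILDataModule lives in datasets/data_interface.py, which is not part of this diff; what its constructor must accept can be read off the call site above. A hypothetical skeleton consistent with that call (names and bodies are assumptions, only the keyword arguments are taken from train.py):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class MILDataModule(pl.LightningDataModule):
    # Skeleton inferred from the train.py call site; not the repo's actual code.
    def __init__(self, data_root: str, label_path: str, batch_size: int = 1,
                 num_workers: int = 8, n_classes: int = 2, **kwargs):
        super().__init__()
        self.data_root = data_root
        self.label_path = label_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.n_classes = n_classes

    def setup(self, stage=None):
        # must assign self.train_data / self.val_data / self.test_data
        # (bag-level Datasets built from data_root and label_path)
        ...

    def train_dataloader(self):
        # batch_size stays 1 because each "sample" is already a bag of tiles
        return DataLoader(self.train_data, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)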
diff --git a/utils/utils.py b/utils/utils.py
index 1b7e44f8b1fd69860ebaa3483688aac78a52a7ba..96ed223d1f73fb9afbf951486376bc02c76ae75a 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -14,7 +14,7 @@ def load_loggers(cfg):
 
     log_path = cfg.General.log_path
     Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = Path(cfg.config).parent
+    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}'
     version_name = Path(cfg.config).name[:-5]
     cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
     print(f'---->Log dir: {cfg.log_path}')
@@ -31,8 +31,10 @@ def load_loggers(cfg):
 
 
 #---->load Callback
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+
 def load_callbacks(cfg):
 
     Mycallbacks = []
@@ -47,7 +49,21 @@ def load_callbacks(cfg):
         verbose=True,
         mode='min'
     )
+    Mycallbacks.append(early_stop_callback)
 
+    progress_bar = RichProgressBar(
+        theme=RichProgressBarTheme(
+            description='green_yellow',
+            progress_bar='green1',
+            progress_bar_finished='green1',
+            batch_progress='green_yellow',
+            time='grey82',
+            processing_speed='grey82',
+            metrics='grey82'
+
+        )
+    )
+    Mycallbacks.append(progress_bar)
     if cfg.General.server == 'train' :
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
@@ -64,7 +80,7 @@ def load_callbacks(cfg):
 import torch
 import torch.nn.functional as F
 def cross_entropy_torch(x, y):
-    x_softmax = [F.softmax(x[i]) for i in range(len(x))]
-    x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(len(y))])
-    loss = - torch.sum(x_log) / len(y)
+    x_softmax = [F.softmax(x[i], dim=0) for i in range(len(x))]
+    x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
+    loss = - torch.sum(x_log) / y.shape[0]
     return loss
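The last hunk fixes a real bug: F.softmax without an explicit dim is deprecated and relies on an inferred axis, so dim=0 pins the softmax to the class axis of each per-sample logit vector. A quick sanity check (a sketch, assuming logits of shape [batch, n_classes] and integer labels) that the patched function agrees with the built-in cross entropy:

import torch
import torch.nn.functional as F

def cross_entropy_torch(x, y):
    # patched version from utils/utils.py above
    x_softmax = [F.softmax(x[i], dim=0) for i in range(len(x))]
    x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
    return -torch.sum(x_log) / y.shape[0]

logits = torch.randn(4, 6)            # n_classes: 6, as in DeepGraft/TransMIL.yaml
labels = torch.tensor([0, 2, 5, 1])
assert torch.allclose(cross_entropy_torch(logits, labels),
                      F.cross_entropy(logits, labels), atol=1e-6)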