diff --git a/DeepGraft/TransMIL_simple.yaml b/DeepGraft/Resnet50.yaml similarity index 93% rename from DeepGraft/TransMIL_simple.yaml rename to DeepGraft/Resnet50.yaml index 4501f2c06242bb6cb951fef458eab000daac0d75..e6b780b18947b315ee77bb0c6b8b0dbd433fa249 100644 --- a/DeepGraft/TransMIL_simple.yaml +++ b/DeepGraft/Resnet50.yaml @@ -5,7 +5,7 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [1] + gpus: [0] epochs: &epoch 200 grad_acc: 2 frozen_bn: False @@ -32,9 +32,8 @@ Data: Model: - name: TransMIL + name: resnet50 n_classes: 2 - backbone: simple Optimizer: diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d83ce0d902228644d4fdd8a3059cd4b135a69fde --- /dev/null +++ b/DeepGraft/TransMIL_debug.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_dino.yaml b/DeepGraft/TransMIL_dino.yaml index ffe987c913ff9075730d7b16b8e9dba16d5d4978..b7161ba63b98a38a09a8a545e0f661b13b342914 100644 --- a/DeepGraft/TransMIL_dino.yaml +++ b/DeepGraft/TransMIL_dino.yaml @@ -6,10 +6,10 @@ General: precision: 16 multi_gpu_mode: dp gpus: [0] - epochs: &epoch 1000 + epochs: &epoch 200 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ @@ -17,7 +17,7 @@ Data: dataset_name: custom data_shuffle: False data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' - label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' fold: 0 nfold: 4 diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79d8ea88865ebddd6a918bdc4b9435b6ba973ff1 --- /dev/null +++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 1000 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 5 + backbone: efficientnet + + +Optimizer: + opt: lookahead_radam + lr: 0.0001 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml 
b/DeepGraft/TransMIL_efficientnet_no_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8780060ebd1475b273b06498800523d7108085b1 --- /dev/null +++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 4 + backbone: efficientnet + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f69b5bfadaa3023c9b6113e59681daee73b29fe2 --- /dev/null +++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 2 + backbone: efficientnet + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_all.yaml b/DeepGraft/TransMIL_resnet18_all.yaml index 8fa5818981b31c1a6c36f34b42e055f2198681fc..f331e4ee092dd8a736e56ff1737514dffd2d1ad2 100644 --- a/DeepGraft/TransMIL_resnet18_all.yaml +++ b/DeepGraft/TransMIL_resnet18_all.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet18_no_other.yaml b/DeepGraft/TransMIL_resnet18_no_other.yaml index 95a9bd64692f5ee12dd822e04fea889b80717457..c7a27f228ec2cbeb77a9456af063f8868c9e877f 100644 --- a/DeepGraft/TransMIL_resnet18_no_other.yaml +++ b/DeepGraft/TransMIL_resnet18_no_other.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet18_no_viral.yaml b/DeepGraft/TransMIL_resnet18_no_viral.yaml index 155b676a24e0541f42f8fee12af2d987217c0525..93054bf660eaecfdc7fd8e9ac7ceecf8007b12b9 100644 --- a/DeepGraft/TransMIL_resnet18_no_viral.yaml +++ b/DeepGraft/TransMIL_resnet18_no_viral.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml index 
e7d9bf0694a227f987d1d2fbf2f0facb53c248d5..c26e1e9ff0329f32efbfd96334c6d1ac957d90bb 100644 --- a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml +++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet50_all.yaml b/DeepGraft/TransMIL_resnet50_all.yaml index eba3a4fa20870ad4fc2b173ccb4e60086ddb3ac5..e6959eac839447b10e1a89ef3f09001ab4318594 100644 --- a/DeepGraft/TransMIL_resnet50_all.yaml +++ b/DeepGraft/TransMIL_resnet50_all.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 1000 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: train #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet50_no_other.yaml b/DeepGraft/TransMIL_resnet50_no_other.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d3cd2aa27224aeecd541edff40643dda60b4609e --- /dev/null +++ b/DeepGraft/TransMIL_resnet50_no_other.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 50 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 5 + backbone: resnet50 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet50_no_viral.yaml b/DeepGraft/TransMIL_resnet50_no_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e3394a8342caab9824528b5ec254bc6c91b7ba3 --- /dev/null +++ b/DeepGraft/TransMIL_resnet50_no_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 50 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 4 + backbone: resnet50 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a756616cd4a128874731b17353d108856f72e9f4 --- /dev/null +++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 50 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: 
'/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
+  fold: 1
+  nfold: 4
+
+  train_dataloader:
+    batch_size: 1
+    num_workers: 8
+
+  test_dataloader:
+    batch_size: 1
+    num_workers: 8
+
+Model:
+  name: TransMIL
+  n_classes: 2
+  backbone: resnet50
+
+
+Optimizer:
+  opt: lookahead_radam
+  lr: 0.0002
+  opt_eps: null
+  opt_betas: null
+  momentum: null
+  weight_decay: 0.00001
+
+Loss:
+  base_loss: CrossEntropyLoss
+
diff --git a/MyBackbone/__init__.py b/MyBackbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fca298edf4ac7bdfb723bcaddf484dad284ae25
--- /dev/null
+++ b/MyBackbone/__init__.py
@@ -0,0 +1,2 @@
+
+from .backbone_factory import init_backbone
\ No newline at end of file
diff --git a/MyBackbone/backbone_factory.py b/MyBackbone/backbone_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff770e583fc3ddc4712424f9381ed9adeb8b7742
--- /dev/null
+++ b/MyBackbone/backbone_factory.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+
+from transformers import AutoFeatureExtractor, ViTModel
+from torchvision import models
+
+
+class View(nn.Module):
+    """Reshape a tensor inside an nn.Sequential."""
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, input):
+        return input.view(*self.shape)
+
+
+def init_backbone(**kargs):
+
+    backbone = kargs['backbone']
+    n_classes = kargs['n_classes']
+    out_features = kargs['out_features']
+
+    if backbone == 'dino' or backbone == 'vit':
+        # preprocessing and feature extraction are bundled into one callable
+        feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+        vit = ViTModel.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+        def model_ft(input):
+            input = feature_extractor(input, return_tensors='pt')
+            return vit(**input)
+
+    elif backbone == 'resnet18':
+        resnet18 = models.resnet18(pretrained=True)
+        modules = list(resnet18.children())[:-1]  # drop the final fc layer
+        res18 = nn.Sequential(*modules)
+        for param in res18.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res18,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 512)),
+            nn.Linear(512, out_features),
+            nn.GELU(),
+        )
+    elif backbone == 'resnet50':
+        resnet50 = models.resnet50(pretrained=True)
+        modules = list(resnet50.children())[:-3]  # keep stages up to the 1024-channel block
+        res50 = nn.Sequential(*modules)
+        for param in res50.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res50,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.GELU(),
+        )
+    elif backbone == 'efficientnet':
+        efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+        for param in efficientnet.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            efficientnet,
+            nn.Linear(1000, 512),
+            nn.GELU(),
+        )
+    elif backbone == 'simple':  # mil-ab attention
+        model_ft = nn.Sequential(
+            nn.Conv2d(3, 20, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            nn.Conv2d(20, 50, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.ReLU(),
+        )
+    else:
+        raise ValueError(f'Unknown backbone: {backbone}')
+
+    return model_ft
diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc
index ed7437016bbcadc48f9bb0a97401529dc574c9b2..14452dde7b34fd11b5255f4ed07bdeb48a27e49f 100644
Binary files a/MyLoss/__pycache__/loss_factory.cpython-39.pyc and b/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/poly_loss.cpython-39.pyc
b/MyLoss/__pycache__/poly_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08edf030976f28689763861889e7ace3d3747a59 Binary files /dev/null and b/MyLoss/__pycache__/poly_loss.cpython-39.pyc differ diff --git a/MyLoss/loss_factory.py b/MyLoss/loss_factory.py index 1dffa6182ef40b1f46436b8a1eadb0a8906c17ff..f3bdcebb96480a0ed6de1e33a0de2defaf1792ea 100755 --- a/MyLoss/loss_factory.py +++ b/MyLoss/loss_factory.py @@ -13,6 +13,7 @@ from .hausdorff import HausdorffDTLoss, HausdorffERLoss from .lovasz_loss import LovaszSoftmax from .ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss,\ WeightedCrossEntropyLossV2, DisPenalizedCE +from .poly_loss import PolyLoss from pytorch_toolbelt import losses as L @@ -22,7 +23,7 @@ def create_loss(args, w1=1.0, w2=0.5): # mode = args.base_loss #BINARY_MODE \MULTICLASS_MODE \MULTILABEL_MODE loss = None if hasattr(nn, conf_loss): - loss = getattr(nn, conf_loss)() + loss = getattr(nn, conf_loss)(label_smoothing=0.5) #binary loss elif conf_loss == "focal": loss = L.BinaryFocalLoss() @@ -46,6 +47,8 @@ def create_loss(args, w1=1.0, w2=0.5): loss = L.JointLoss(BCEWithLogitsLoss(), L.BinaryDiceLogLoss(), w1, w2) elif conf_loss == "reduced_focal": loss = L.BinaryFocalLoss(reduced=True) + elif conf_loss == "polyloss": + loss = PolyLoss(softmax=False) else: assert False and "Invalid loss" raise ValueError diff --git a/MyLoss/poly_loss.py b/MyLoss/poly_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e3705458eff14c788f2d6719b923d6498cc3c74d --- /dev/null +++ b/MyLoss/poly_loss.py @@ -0,0 +1,84 @@ +# From https://github.com/yiyixuxu/polyloss-pytorch + +import warnings +from typing import Optional + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + + +def to_one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: + # if `dim` is bigger, add singleton dim at the end + if labels.ndim < dim + 1: + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) + labels = torch.reshape(labels, shape) + + sh = list(labels.shape) + + if sh[dim] != 1: + raise AssertionError("labels should have a channel with length equal to one.") + + sh[dim] = num_classes + + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) + labels = o.scatter_(dim=dim, index=labels.long(), value=1) + + return labels + + +class PolyLoss(_Loss): + def __init__(self, + softmax: bool = False, + ce_weight: Optional[torch.Tensor] = None, + reduction: str = 'mean', + epsilon: float = 1.0, + ) -> None: + super().__init__() + self.softmax = softmax + self.reduction = reduction + self.epsilon = epsilon + self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction='none') + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + You can pass logits or probabilities as input, if pass logit, must set softmax=True + target: the shape should be BNH[WD] (one-hot format) or B1H[WD], where N is the number of classes. + It should contain binary values + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 
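+            Example (sketch): with softmax=True and epsilon=1.0, the result is
+                CE(input, target) + (1 - p_t), where p_t is the softmax
+                probability assigned to the true class (see the code below).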
+ """ + n_pred_ch, n_target_ch = input.shape[1], target.shape[1] + # target not in one-hot encode format, has shape B1H[WD] + if n_pred_ch != n_target_ch: + # squeeze out the channel dimension of size 1 to calculate ce loss + self.ce_loss = self.cross_entropy(input, torch.squeeze(target, dim=1).long()) + # convert into one-hot format to calculate ce loss + target = to_one_hot(target, num_classes=n_pred_ch) + else: + # # target is in the one-hot format, convert to BH[WD] format to calculate ce loss + self.ce_loss = self.cross_entropy(input, torch.argmax(target, dim=1)) + + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + + pt = (input * target).sum(dim=1) # BH[WD] + poly_loss = self.ce_loss + self.epsilon * (1 - pt) + + if self.reduction == 'mean': + polyl = torch.mean(poly_loss) # the batch and channel average + elif self.reduction == 'sum': + polyl = torch.sum(poly_loss) # sum over the batch and channel dims + elif self.reduction == 'none': + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + # BH[WD] -> BH1[WD] + polyl = poly_loss.unsqueeze(1) + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return (polyl) \ No newline at end of file diff --git a/README.md b/README.md index 04c7dbabf81b17979326fdff17670a1b5169e596..effc0fe369c45153ec681403569c710ea114baf3 100644 --- a/README.md +++ b/README.md @@ -9,3 +9,7 @@ python train.py --stage='train' --config='Camelyon/TransMIL.yaml' --gpus=0 --fo ```python python train.py --stage='test' --config='Camelyon/TransMIL.yaml' --gpus=0 --fold=0 ``` + + +### Changes Made: + diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc index 147b3fc3c4628d6eed5bebd6ff248d31aaeebb34..4a200bbb7d34328019d9bb9e604b0524127ae863 100644 Binary files a/datasets/__pycache__/custom_dataloader.cpython-39.pyc and b/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc index 4af0141496c11d6f97255add8746fa0067e55636..9550db1509e8d477a1c15534a76c6f87976fbec4 100644 Binary files a/datasets/__pycache__/data_interface.cpython-39.pyc and b/datasets/__pycache__/data_interface.cpython-39.pyc differ diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py index ddc2ed36c64fc46f0844673514b51cead0052bc6..02850f5db5e263a9d5cdbbfff923952dd5dbfa52 100644 --- a/datasets/custom_dataloader.py +++ b/datasets/custom_dataloader.py @@ -9,12 +9,25 @@ from torch.utils.data.dataloader import DataLoader from tqdm import tqdm # from histoTransforms import RandomHueSaturationValue import torchvision.transforms as transforms +import torchvision import torch.nn.functional as F import csv from PIL import Image import cv2 import pandas as pd import json +import albumentations as A +from albumentations.pytorch import ToTensorV2 +from transformers import AutoFeatureExtractor +from imgaug import augmenters as iaa +import imgaug as ia +from torchsampler import ImbalancedDatasetSampler + + +class RangeNormalization(object): + def __call__(self, sample): + img = sample + return (img / 255.0 - 0.5) / 0.5 class HDF5MILDataloader(data.Dataset): """Represents an abstract HDF5 dataset. For single H5 container! 
@@ -28,68 +41,89 @@ class HDF5MILDataloader(data.Dataset): data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). """ - def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20): + def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024): super().__init__() self.data_info = [] self.data_cache = {} self.slideLabelDict = {} + self.files = [] self.data_cache_size = data_cache_size self.mode = mode self.file_path = file_path # self.csv_path = csv_path self.label_path = label_path self.n_classes = n_classes - self.bag_size = 120 + self.bag_size = bag_size + self.backbone = backbone # self.label_file = label_path recursive = True + # read labels and slide_path from csv - - # df = pd.read_csv(self.csv_path) - # labels = df.LABEL - # slides = df.FILENAME with open(self.label_path, 'r') as f: - self.slideLabelDict = json.load(f)[mode] - - self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict} - - - # if Path(slides[0]).suffix: - # slides = list(map(lambda x: Path(x).stem, slides)) - - # print(labels) - # print(slides) - # self.slideLabelDict = dict(zip(slides, labels)) - # print(self.slideLabelDict) - - #check if files in slideLabelDict, only take files that are available. + temp_slide_label_dict = json.load(f)[mode] + for (x, y) in temp_slide_label_dict: + x = Path(x).stem + x_complete_path = Path(self.file_path)/Path(x + '.hdf5') + if x_complete_path.is_file(): + self.slideLabelDict[x] = y + self.files.append(x_complete_path) - files_in_path = list(Path(self.file_path).rglob('*.hdf5')) - files_in_path = [x.stem for x in files_in_path] - # print(len(files_in_path)) - # print(files_in_path) - # print(list(self.slideLabelDict.keys())) - # for x in list(self.slideLabelDict.keys()): - # if x in files_in_path: - # path = Path(self.file_path) / (x + '.hdf5') - # print(path) - - self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path] - - print(len(self.files)) - # self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys()))) for h5dataset_fp in tqdm(self.files): - # print(h5dataset_fp) self._add_data_infos(str(h5dataset_fp.resolve()), load_data) - # print(self.data_info) - self.resize_transforms = transforms.Compose([ - transforms.ToPILImage(), - transforms.Resize(256), + + self.resize_transforms = A.Compose([ + A.SmallestMaxSize(max_size=256) ]) + sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1") + sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2") + sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3") + sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4") + sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5") + + self.train_transforms = iaa.Sequential([ + iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"), + sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")), + iaa.Fliplr(0.5, name="MyFlipLR"), + iaa.Flipud(0.5, name="MyFlipUD"), + sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")), + iaa.OneOf([ + sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")), + sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")), + sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine")) + ], name="MyOneOf") + + ], name="MyAug") + + # 
self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+        #     # A.RandomGamma(),
+        #     # A.HorizontalFlip(),
+        #     # A.VerticalFlip(),
+        #     # A.RandomRotate90(),
+        #     # A.OneOf([
+        #     #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #     #     A.Affine(
+        #     #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+        #     #         rotate=(-45, 45),
+        #     #         shear=(-4, 4),
+        #     #         cval=8,
+        #     #     )
+        #     # ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+        ])
         self.img_transforms = transforms.Compose([
             transforms.RandomHorizontalFlip(p=1),
             transforms.RandomVerticalFlip(p=1),
@@ -100,32 +134,49 @@ class HDF5MILDataloader(data.Dataset):
             RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
             transforms.ToTensor()
         ])
+        if self.backbone == 'dino':
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
         # self._add_data_infos(load_data)
 
     def __getitem__(self, index):
         # get data
         batch, label, name = self.get_data(index)
         out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
 
         if self.mode == 'train':
             # print(img)
             # print(img.shape)
-            for img in batch:
-                img = self.img_transforms(img)
-                img = self.hsv_transforms(img)
+            for img in batch: # expects numpy uint8
+                img = img.numpy().astype(np.uint8)
+                # img = self.resize_transforms(img)
+                img = seq_img_d.augment_image(img)
+                if self.backbone == 'dino':
+                    # the DINO extractor resizes, normalizes and tensorizes on
+                    # its own, so it must run after imgaug, not before it
+                    img = self.feature_extractor(images=img, return_tensors='pt')['pixel_values'].squeeze(0)
+                else:
+                    img = self.val_transforms(img)
                 out_batch.append(img)
 
         else:
             for img in batch:
-                img = transforms.functional.to_tensor(img)
+                img = img.numpy().astype(np.uint8)
+                # albumentations transforms take named arguments
+                img = self.resize_transforms(image=img)['image']
+                if self.backbone == 'dino':
+                    img = self.feature_extractor(images=img, return_tensors='pt')['pixel_values'].squeeze(0)
+                else:
+                    img = self.val_transforms(img)
                 out_batch.append(img)
+
         if len(out_batch) == 0:
             # print(name)
-            out_batch = torch.randn(100,3,256,256)
+            out_batch = torch.randn(self.bag_size,3,256,256)
         else:
             out_batch = torch.stack(out_batch)
-            out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+            # print(out_batch.shape)
+            # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
 
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
@@ -138,29 +185,7 @@ class HDF5MILDataloader(data.Dataset):
         wsi_name = Path(file_path).stem
         if wsi_name in self.slideLabelDict:
             label = self.slideLabelDict[wsi_name]
-            wsi_batch = []
-            # with h5py.File(file_path, 'r') as h5_file:
-            #     numKeys = len(h5_file.keys())
-            #     sample = list(h5_file.keys())[0]
-            #     shape = (numKeys,) + h5_file[sample][:].shape
-            #     for tile in h5_file.keys():
-            #         img = h5_file[tile][:]
-
-            # print(img)
-            # if type == 'images':
-            #     t = 'data'
-            # else:
-            #     t = 'label'
             idx = -1
-            # if load_data:
-            #     for tile in h5_file.keys():
-            #         img = h5_file[tile][:]
-            #         img = img.astype(np.uint8)
-            #         img = self.resize_transforms(img)
-            #         wsi_batch.append(img)
-            #     idx = self._add_to_cache(wsi_batch, file_path)
-            # wsi_batch.append(img)
-            # self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx})
             self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx':
idx}) def _load_data(self, file_path): @@ -173,30 +198,15 @@ class HDF5MILDataloader(data.Dataset): for tile in h5_file.keys(): img = h5_file[tile][:] img = img.astype(np.uint8) - img = self.resize_transforms(img) + img = torch.from_numpy(img) + # img = self.resize_transforms(img) wsi_batch.append(img) + wsi_batch = torch.stack(wsi_batch) + wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size) idx = self._add_to_cache(wsi_batch, file_path) file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) self.data_info[file_idx + idx]['cache_idx'] = idx - # for type in ['images', 'labels']: - # for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()): - # img = h5_file[data_path][:] - # idx = self._add_to_cache(img, data_path) - # file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path) - # self.data_info[file_idx + idx]['cache_idx'] = idx - # for gname, group in h5_file.items(): - # for dname, ds in group.items(): - # # add data to the data cache and retrieve - # # the cache index - # idx = self._add_to_cache(ds.value, file_path) - - # # find the beginning index of the hdf5 file we are looking for - # file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path) - - # # the data info should have the same index since we loaded it in the same way - # self.data_info[file_idx + idx]['cache_idx'] = idx - # remove an element from data cache if size was exceeded if len(self.data_cache) > self.data_cache_size: # remove one item from the cache at random @@ -223,6 +233,182 @@ class HDF5MILDataloader(data.Dataset): # data_info_type = [di for di in self.data_info if di['type'] == type] # return data_info_type + def get_name(self, i): + # name = self.get_data_infos(type)[i]['name'] + name = self.data_info[i]['name'] + return name + + def get_labels(self, indices): + + return [self.data_info[i]['label'] for i in indices] + + def get_data(self, i): + """Call this function anytime you want to access a chunk of data from the + dataset. This will make sure that the data is loaded in case it is + not part of the data cache. + i = index + """ + # fp = self.get_data_infos(type)[i]['data_path'] + fp = self.data_info[i]['data_path'] + if fp not in self.data_cache: + self._load_data(fp) + + # get new cache_idx assigned by _load_data_info + # cache_idx = self.get_data_infos(type)[i]['cache_idx'] + cache_idx = self.data_info[i]['cache_idx'] + label = self.data_info[i]['label'] + name = self.data_info[i]['name'] + # print(self.data_cache[fp][cache_idx]) + return self.data_cache[fp][cache_idx], label, name + +class GenesisDataloader(data.Dataset): + """Represents an abstract HDF5 dataset. For single H5 container! + + Input params: + file_path: Path to the folder containing the dataset (one or multiple HDF5 files). + mode: 'train' or 'test' + load_data: If True, loads all the data immediately into RAM. Use this if + the dataset is fits into memory. Otherwise, leave this at false and + the data will load lazily. + data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). 
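+        Note: unlike HDF5MILDataloader above, __getitem__ here returns a single
+        tile as a dict with 'data' and 'label' rather than a bag of tiles.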
+ + """ + def __init__(self, file_path, label_path, mode, load_data=False, data_cache_size=5, debug=False): + super().__init__() + + self.data_info = [] + self.data_cache = {} + self.slideLabelDict = {} + self.files = [] + self.data_cache_size = data_cache_size + self.mode = mode + self.file_path = file_path + # self.csv_path = csv_path + self.label_path = label_path + # self.n_classes = n_classes + self.bag_size = 120 + # self.transforms = transforms + self.input_size = 256 + + # for + # self.files = list(Path(self.file_path).rglob('*.hdf5')) + home = Path.cwd().parts[1] + with open(self.label_path, 'r') as f: + temp_slide_label_dict = json.load(f)[mode] + for x in temp_slide_label_dict: + + if Path(x).parts[1] != home: + x = x.replace(Path(x).parts[1], home) + self.files.append(Path(x)) + # x = Path(x).stem + # x_complete_path = Path(self.file_path)/Path(x + '.hdf5') + # if x_complete_path.is_file(): + # self.slideLabelDict[x] = y + # self.files.append(x_complete_path) + + for h5dataset_fp in tqdm(self.files): + self._add_data_infos(str(h5dataset_fp.resolve()), load_data) + + self.resize_transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToPILImage(), + torchvision.transforms.Resize(self.input_size), + ]) + self.resize_transforms = A.Compose([ + A.SmallestMaxSize(max_size=256) + ]) + sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1") + sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2") + sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3") + sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4") + sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5") + + self.train_transforms = iaa.Sequential([ + iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"), + sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")), + iaa.Fliplr(0.5, name="MyFlipLR"), + iaa.Flipud(0.5, name="MyFlipUD"), + sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")), + iaa.OneOf([ + sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")), + sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")), + sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine")) + ], name="MyOneOf") + + ], name="MyAug") + self.val_transforms = transforms.Compose([ + # A.Normalize(), + # ToTensorV2(), + RangeNormalization(), + transforms.ToTensor(), + + ]) + + + def __getitem__(self, index): + # get data + img, name = self.get_data(index) + # out_batch = [] + seq_img_d = self.train_transforms.to_deterministic() + + if self.mode == 'train': + + img = img.numpy().astype(np.uint8) + img = seq_img_d.augment_image(img) + img = self.val_transforms(img) + else: + img = img.numpy().astype(np.uint8) + img = self.val_transforms(img) + # out_batch.append(img) + + return {'data': img, 'label': label} + # return out_batch, label + + def __len__(self): + return len(self.data_info) + + def _add_data_infos(self, file_path, load_data): + img_name = Path(file_path).stem + label = Path(file_path).parts[-2] + + idx = -1 + self.data_info.append({'data_path': file_path, 'label': label, 'name': img_name, 'cache_idx': idx}) + + def _load_data(self, file_path): + """Load data to the cache given the file + path and update the cache index in the + data_info structure. 
+ """ + with h5py.File(file_path, 'r') as h5_file: + + tile = list(h5_file.keys())[0] + img = h5_file[tile][:] + img = img.astype(np.uint8) + # img = self.resize_transforms(img) + # wsi_batch.append(img) + idx = self._add_to_cache(img, file_path) + file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) + self.data_info[file_idx + idx]['cache_idx'] = idx + + # remove an element from data cache if size was exceeded + if len(self.data_cache) > self.data_cache_size: + # remove one item from the cache at random + removal_keys = list(self.data_cache) + removal_keys.remove(file_path) + self.data_cache.pop(removal_keys[0]) + # remove invalid cache_idx + # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + self.data_info = [{'data_path': di['data_path'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + + def _add_to_cache(self, data, data_path): + """Adds data to the cache and returns its index. There is one cache + list for every file_path, containing all datasets in that file. + """ + if data_path not in self.data_cache: + self.data_cache[data_path] = [data] + else: + self.data_cache[data_path].append(data) + return len(self.data_cache[data_path]) - 1 + def get_name(self, i): # name = self.get_data_infos(type)[i]['name'] name = self.data_info[i]['name'] @@ -248,6 +434,45 @@ class HDF5MILDataloader(data.Dataset): return self.data_cache[fp][cache_idx], label, name +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + +def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512): + + # get up to bag_size elements + bag_idxs = torch.randperm(bag.shape[0])[:bag_size] + bag_samples = bag[bag_idxs] + + # zero-pad if we don't have enough samples + zero_padded = torch.cat((bag_samples, + torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3]))) + return zero_padded, min(bag_size, len(bag)) + + class RandomHueSaturationValue(object): def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): @@ -285,48 +510,81 @@ if __name__ == '__main__': train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' - label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json' - output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/' 
- # os.makedirs(output_path, exist_ok=True) - - - dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6) - data = DataLoader(dataset, batch_size=1) + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments' + os.makedirs(output_path, exist_ok=True) + + n_classes = 2 + + dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20) + # print(dataset.dataset) + a = int(len(dataset)* 0.8) + b = int(len(dataset) - a) + train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b]) + dl = DataLoader(train_ds, None, sampler=ImbalancedDatasetSampler(train_ds), num_workers=5) + dl = DataLoader(train_ds, None, num_workers=5) + + # data = DataLoader(dataset, batch_size=1) # print(len(dataset)) - x = 0 + # # x = 0 c = 0 - for item in data: - if c >=10: - break + label_count = [0] *n_classes + for item in dl: + # if c >=10: + # break bag, label, name = item - print(bag) - # # print(bag.shape) - # if bag.shape[1] == 1: - # print(name) - # print(bag.shape) + label_count[np.argmax(label)] += 1 + print(label_count) + print(len(train_ds)) + # # # print(bag.shape) + # # if bag.shape[1] == 1: + # # print(name) + # # print(bag.shape) # print(bag.shape) - # print(name) - # out_dir = Path(output_path) / name - # os.makedirs(out_dir, exist_ok=True) - - # # print(item[2]) - # # print(len(item)) - # # print(item[1]) - # # print(data.shape) - # # data = data.squeeze() - # bag = item[0] - # bag = bag.squeeze() - # for i in range(bag.shape[0]): - # img = bag[i, :, :, :] - # img = img.squeeze() - # img = img*255 - # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + # # out_dir = Path(output_path) / name + # # os.makedirs(out_dir, exist_ok=True) + + # # # print(item[2]) + # # # print(len(item)) + # # # print(item[1]) + # # # print(data.shape) + # # # data = data.squeeze() + # # bag = item[0] + # bag = bag.squeeze() + # original = original.squeeze() + # for i in range(bag.shape[0]): + # img = bag[i, :, :, :] + # img = img.squeeze() + + # img = ((img-img.min())/(img.max() - img.min())) * 255 + # print(img) + # # print(img) + # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + + # img = Image.fromarray(img) + # img = img.convert('RGB') + # img.save(f'{output_path}/{i}.png') + + - # img = Image.fromarray(img) - # img = img.convert('RGB') - # img.save(f'{out_dir}/{i}.png') - c += 1 + # o_img = original[i,:,:,:] + # o_img = o_img.squeeze() + # print(o_img.shape) + # o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255 + # o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0) + # o_img = Image.fromarray(o_img) + # o_img = o_img.convert('RGB') + # o_img.save(f'{output_path}/{i}_original.png') + # c += 1 + # break # else: break # print(data.shape) - # print(label) \ No newline at end of file + # print(label) + # a = [torch.Tensor((3,256,256))]*3 + # b = torch.stack(a) + # print(b) + # c = to_fixed_size_bag(b, 512) + # print(c) \ No newline at end of file diff --git a/datasets/data_interface.py b/datasets/data_interface.py index 12a0f8c450b4945658d3566cc5c157116bccab66..056e6ff09bcf768c91e0f51e44640487ea4fbee1 100644 --- a/datasets/data_interface.py +++ b/datasets/data_interface.py @@ -8,6 +8,8 @@ from torchvision import transforms from .camel_dataloader import FeatureBagLoader from .custom_dataloader import HDF5MILDataloader from pathlib import Path +from transformers import 
AutoFeatureExtractor +from torchsampler import ImbalancedDatasetSampler class DataInterface(pl.LightningDataModule): @@ -56,9 +58,10 @@ class DataInterface(pl.LightningDataModule): train=True) a = int(len(dataset)* 0.8) b = int(len(dataset) - a) - print(a) - print(b) - self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) + # print(a) + # print(b) + self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) # returns data.Subset + # self.train_dataset = self.instancialize(state='train') # self.val_dataset = self.instancialize(state='val') @@ -72,7 +75,7 @@ class DataInterface(pl.LightningDataModule): def train_dataloader(self): - return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=True) + return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False) def val_dataloader(self): return DataLoader(self.val_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False) @@ -106,7 +109,7 @@ class DataInterface(pl.LightningDataModule): class MILDataModule(pl.LightningDataModule): - def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, **kwargs): + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs): super().__init__() self.data_root = data_root self.label_path = label_path @@ -121,41 +124,74 @@ class MILDataModule(pl.LightningDataModule): self.num_bags_test = 50 self.seed = 1 + self.backbone = backbone self.cache = True + self.fe_transform = None def setup(self, stage: Optional[str] = None) -> None: - # if self.n_classes == 2: - # if stage in (None, 'fit'): - # dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes) - # a = int(len(dataset)* 0.8) - # b = int(len(dataset) - a) - # self.train_data, self.valid_data = random_split(dataset, [a, b]) - - # if stage in (None, 'test'): - # self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes) - # else: home = Path.cwd().parts[1] - # self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json' - # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv' - # test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv' - if stage in (None, 'fit'): - dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) - # print(len(dataset)) + dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone) a = int(len(dataset)* 0.8) b = int(len(dataset) - a) self.train_data, self.valid_data = random_split(dataset, [a, b]) if stage in (None, 'test'): - self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes) + self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone) return super().setup(stage=stage) def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + 
#sampler=ImbalancedDatasetSampler(self.train_data)
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
+
+class DataModule(pl.LightningDataModule):
+
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+        super().__init__()
+        self.data_root = data_root
+        self.label_path = label_path
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.image_size = 384
+        self.n_classes = n_classes
+        self.target_number = 9
+        self.mean_bag_length = 10
+        self.var_bag_length = 2
+        self.num_bags_train = 200
+        self.num_bags_test = 50
+        self.seed = 1
+
+        self.backbone = backbone
+        self.cache = True
+        self.fe_transform = None
+
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        home = Path.cwd().parts[1]
+
+        if stage in (None, 'fit'):
+            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
+            a = int(len(dataset)* 0.8)
+            b = int(len(dataset) - a)
+            self.train_data, self.valid_data = random_split(dataset, [a, b])
+
+        if stage in (None, 'test'):
+            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+
+        return super().setup(stage=stage)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True,
+                          #sampler=ImbalancedDatasetSampler(self.train_data),
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
diff --git a/fine_tune.py b/fine_tune.py
new file mode 100644
index 0000000000000000000000000000000000000000..5898e755c5d3e4b666599ae324dc81be41425962
--- /dev/null
+++ b/fine_tune.py
@@ -0,0 +1,44 @@
+from transformers import AutoFeatureExtractor, ViTModel
+from transformers import Trainer, TrainingArguments
+from torchvision import models
+import torch
+from datasets.custom_dataloader import DinoDataloader
+
+
+def fine_tune_transformer(args):
+
+    data_path = args.data_path
+    model = args.model
+    n_classes = args.n_classes
+
+    if model == 'dino':
+        feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+        model_ft = ViTModel.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+    training_args = TrainingArguments(
+        output_dir = f'logs/fine_tune/{model}',
+        per_device_train_batch_size=16,
+        evaluation_strategy="steps",
+        num_train_epochs=4,
+        fp16=True,
+        save_steps=100,
+        eval_steps=100,
+        logging_steps=10,
+        learning_rate=2e-4,
+        save_total_limit=2,
+        remove_unused_columns=False,
+        push_to_hub=False,
+        report_to='tensorboard',
+        load_best_model_at_end=True,
+    )
+
+    dataset = DinoDataloader(args.data_path, mode='train') #, transforms=transform
+
+    trainer = Trainer(
+        model=model_ft,
+        args=training_args,
+        train_dataset=dataset,
+        # NOTE: evaluation_strategy='steps' and load_best_model_at_end also
+        # need an eval_dataset here before train() will run evaluations
+    )
+    trainer.train()
\ No newline at end of file
diff --git a/models/TransMIL.py b/models/TransMIL.py
index ce40a26b37b1886bf5698ee4ab8ecf07c1e4e2c8..69089de0bd934ea38b2854f513295b21cbb73a03 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -86,7 +86,7 @@ class TransMIL(nn.Module):
         h = self.norm(h)[:,0]
 
         #---->predict
-        logits =
self._fc2(torch.sigmoid(h)) #[B, n_classes] + logits = self._fc2(h) #[B, n_classes] Y_hat = torch.argmax(logits, dim=1) Y_prob = F.softmax(logits, dim = 1) results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc index 4e1ddff6d6f3cbadfd7f1c0c4686a7806f5896e8..cb0eb6fd5085feda9604ec0f24420d85eb1d939d 100644 Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and b/models/__pycache__/TransMIL.cpython-39.pyc differ diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc index 3e81c0ccdd33e6fe68a5ffe1b602990134944c80..0bf337675c76de4097e7f7723b99de0120f9594c 100644 Binary files a/models/__pycache__/model_interface.cpython-39.pyc and b/models/__pycache__/model_interface.cpython-39.pyc differ diff --git a/models/__pycache__/resnet50.cpython-39.pyc b/models/__pycache__/resnet50.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfb660f9829304cb853031dd3d3274396afc38be Binary files /dev/null and b/models/__pycache__/resnet50.cpython-39.pyc differ diff --git a/models/model_interface.py b/models/model_interface.py index 1b0f6e19f8429b764e6fd45b96bf821981d0731e..60b5cc73d6bec977216cdde6bb8bcee3fd267034 100755 --- a/models/model_interface.py +++ b/models/model_interface.py @@ -12,18 +12,24 @@ from matplotlib import pyplot as plt from MyOptimizer import create_optimizer from MyLoss import create_loss from utils.utils import cross_entropy_torch +from timm.loss import AsymmetricLossSingleLabel +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy #----> import torch import torch.nn as nn import torch.nn.functional as F import torchmetrics +from torch import optim as optim #----> import pytorch_lightning as pl from .vision_transformer import vit_small from torchvision import models from torchvision.models import resnet +from transformers import AutoFeatureExtractor, ViTModel + +from captum.attr import LayerGradCam class ModelInterface(pl.LightningModule): @@ -33,13 +39,17 @@ class ModelInterface(pl.LightningModule): self.save_hyperparameters() self.load_model() self.loss = create_loss(loss) + # self.asl = AsymmetricLossSingleLabel() + # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1) + + # self.loss = self.optimizer = optimizer self.n_classes = model.n_classes self.log_path = kargs['log'] #---->acc self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] - + # print(self.experiment) #---->Metrics if self.n_classes > 2: self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted') @@ -73,35 +83,12 @@ class ModelInterface(pl.LightningModule): #--->random self.shuffle = kargs['data'].data_shuffle self.count = 0 + self.backbone = kargs['backbone'] self.out_features = 512 if kargs['backbone'] == 'dino': - #---> dino feature extractor - arch = 'vit_small' - patch_size = 16 - n_last_blocks = 4 - # num_labels = 1000 - avgpool_patchtokens = False - home = Path.cwd().parts[1] - - weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth' - model = vit_small(patch_size, num_classes=0) - # model.eval() - # set_parameter_requires_grad(model, feature_extracting) - for param in model.parameters(): - param.requires_grad = False - # print(model.embed_dim) - # embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens)) - # model.eval() - # print(embed_dim) - linear = nn.Linear(model.embed_dim, 
self.out_features)
-            linear.weight.data.normal_(mean=0.0, std=0.01)
-            linear.bias.data.zero_()
-
-            self.model_ft = nn.Sequential(
-                model,
-                linear,
-            )
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+            self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
         elif kargs['backbone'] == 'resnet18':
             resnet18 = models.resnet18(pretrained=True)
             modules = list(resnet18.children())[:-1]
@@ -109,7 +96,6 @@
 
             res18 = nn.Sequential(
                 *modules,
-
             )
             for param in res18.parameters():
                 param.requires_grad = False
@@ -118,7 +104,7 @@
                 nn.AdaptiveAvgPool2d(1),
                 View((-1, 512)),
                 nn.Linear(512, self.out_features),
-                nn.ReLU(),
+                nn.GELU(),
             )
         elif kargs['backbone'] == 'resnet50':
@@ -135,7 +121,17 @@
                 nn.AdaptiveAvgPool2d(1),
                 View((-1, 1024)),
                 nn.Linear(1024, self.out_features),
-                nn.ReLU()
+                # nn.GELU()
             )
+        elif kargs['backbone'] == 'efficientnet':
+            efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+            for param in efficientnet.parameters():
+                param.requires_grad = False
+            # efn = list(efficientnet.children())[:-1]
+            efficientnet.classifier.fc = nn.Linear(1280, self.out_features)
+            self.model_ft = nn.Sequential(
+                efficientnet,
+                nn.GELU(),
+            )
         elif kargs['backbone'] == 'simple': #mil-ab attention
             feature_extracting = False
@@ -151,21 +147,21 @@
                 nn.ReLU(),
             )
 
-    #---->remove v_num
-    # def get_progress_bar_dict(self):
-    #     # don't show the version number
-    #     items = super().get_progress_bar_dict()
-    #     items.pop("v_num", None)
-    #     return items
-
     def training_step(self, batch, batch_idx):
         #---->inference
+
         data, label, _ = batch
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
+        # print(data)
         # print(data.shape)
-        features = self.model_ft(data)
-
+        if self.backbone == 'dino':
+            # the dataloader already applies the DINO feature extractor, so
+            # data arrives here as a pixel tensor, not a BatchFeature dict
+            features = self.model_ft(data)
+            features = features.last_hidden_state
+        else:
+            features = self.model_ft(data)
         features = features.unsqueeze(0)
         # print(features.shape)
         # features = features.squeeze()
@@ -177,21 +171,28 @@
 
         #---->loss
         loss = self.loss(logits, label)
+        # loss = self.asl(logits, label.squeeze())
 
         #---->acc log
         # print(label)
-        Y_hat = int(Y_hat)
+        # Y_hat = int(Y_hat)
         # if self.n_classes == 2:
         #     Y = int(label[0][1])
         # else:
         Y = torch.argmax(label)
         # Y = int(label[0])
         self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (Y_hat == Y)
+        self.data[Y]["correct"] += (int(Y_hat) == Y)
 
-        return {'loss': loss}
+        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label}
 
     def training_epoch_end(self, training_step_outputs):
+        # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
+        max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs])
+        # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
+        target = torch.cat([x['label'] for x in training_step_outputs])
+        target = torch.argmax(target, dim=1)
         for c in range(self.n_classes):
             count = self.data[c]["count"]
             correct = self.data[c]["correct"]
@@ -202,12 +203,19 @@
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
+        # print('max_probs: ', max_probs)
+        # print('probs: ', probs)
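+        # probs is [n_samples, n_classes] of softmax outputs; labels arrive
+        # one-hot, so the argmax above recovers the class indices that the
+        # torchmetrics calls below expect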
+        if self.current_epoch % 10 == 0:
+            self.log_confusion_matrix(probs, target, stage='train')
+
+        self.log('Train/auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+
     def validation_step(self, batch, batch_idx):
 
         data, label, _ = batch
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
         features = self.model_ft(data)
         features = features.unsqueeze(0)
@@ -224,20 +232,23 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
 
     def validation_epoch_end(self, val_step_outputs):
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
-        probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
+        # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
+        # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
+        target = torch.cat([x['label'] for x in val_step_outputs])
+        target = torch.argmax(target, dim=1)
 
         #---->
         # logits = logits.long()
         # target = target.squeeze().long()
         # logits = logits.squeeze(0)
         self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
-        self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+        self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
 
         # print(max_probs.squeeze(0).shape)
         # print(target.shape)
@@ -245,12 +256,8 @@ class ModelInterface(pl.LightningModule):
                  on_epoch = True, logger = True)
 
         #----> log confusion matrix
-        confmat = self.confusion_matrix(max_probs.squeeze(), target)
-        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        plt.figure()
-        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        plt.close(fig_)
-        self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
+        self.log_confusion_matrix(probs, target, stage='val')
+
 
         #---->acc log
         for c in range(self.n_classes):
@@ -267,18 +274,12 @@ class ModelInterface(pl.LightningModule):
         if self.shuffle == True:
             self.count = self.count+1
             random.seed(self.count*50)
-
-
-
-    def configure_optimizers(self):
-        optimizer = create_optimizer(self.optimizer, self.model)
-        return [optimizer]
 
     def test_step(self, batch, batch_idx):
 
         data, label, _ = batch
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
         features = self.model_ft(data)
         features = features.unsqueeze(0)
@@ -292,12 +293,14 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
 
     def test_epoch_end(self, output_results):
-        probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in output_results])
         max_probs = torch.stack([x['Y_hat'] for x in output_results])
-        target = torch.stack([x['label'] for x in output_results], dim = 0)
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.cat([x['label'] for x in output_results])
+        target = torch.argmax(target, dim=1)
 
         #---->
         auc = self.AUROC(probs, target.squeeze())
@@ -326,19 +329,16 @@ class ModelInterface(pl.LightningModule):
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
-        confmat = self.confusion_matrix(max_probs.squeeze(), target)
-        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        plt.figure()
-        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        # plt.close(fig_)
-        # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
-        plt.savefig(f'{self.log_path}/cm_test')
-        plt.close(fig_)
-
+        self.log_confusion_matrix(probs, target, stage='test')
 
         #---->
         result = pd.DataFrame([metrics])
         result.to_csv(self.log_path / 'result.csv')
 
+    def configure_optimizers(self):
+        # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
+        optimizer = create_optimizer(self.optimizer, self.model)
+        return optimizer
+
     def load_model(self):
         name = self.hparams.model.name
@@ -350,6 +350,7 @@ class ModelInterface(pl.LightningModule):
         else:
             camel_name = name
         try:
+
             Model = getattr(importlib.import_module(
                 f'models.{name}'), camel_name)
         except:
@@ -371,6 +372,20 @@ class ModelInterface(pl.LightningModule):
         args1.update(other_args)
         return Model(**args1)
 
+    def log_confusion_matrix(self, max_probs, target, stage):
+        confmat = self.confusion_matrix(max_probs.squeeze(), target)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        plt.figure()
+        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+        # plt.close(fig_)
+        # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
+        self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+        if stage == 'test':
+            plt.savefig(f'{self.log_path}/cm_test')
+        plt.close(fig_)
+        # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
+
 class View(nn.Module):
     def __init__(self, shape):
         super().__init__()
@@ -383,4 +398,5 @@ class View(nn.Module):
         # batch_size = input.size(0)
         # shape = (batch_size, *self.shape)
         out = input.view(*self.shape)
-        return out
\ No newline at end of file
+        return out
+
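Two changes in the hunks above matter for metric correctness: validation and test steps now return the full one-hot `label` tensor instead of the scalar `Y`, and the epoch-end hooks aggregate with `torch.cat` plus `argmax` instead of `torch.stack`, so `AUROC` and the confusion matrix receive `(N, n_classes)` probabilities against `(N,)` integer targets. A minimal, self-contained sketch of that aggregation, using toy tensors rather than repo data:

# Sketch of the epoch-end aggregation above, assuming each step returns
# {'Y_prob': (1, n_classes) probabilities, 'label': (1, n_classes) one-hot}.
import torch

outputs = [
    {'Y_prob': torch.tensor([[0.8, 0.2]]), 'label': torch.tensor([[1., 0.]])},
    {'Y_prob': torch.tensor([[0.3, 0.7]]), 'label': torch.tensor([[0., 1.]])},
]

probs = torch.cat([x['Y_prob'] for x in outputs])    # (N, n_classes)
target = torch.cat([x['label'] for x in outputs])    # (N, n_classes) one-hot
target = torch.argmax(target, dim=1)                 # (N,) class indices

print(probs.shape, target)  # torch.Size([2, 2]) tensor([0, 1])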
diff --git a/models/resnet50.py b/models/resnet50.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e23d7460df4b37d9455d9ab7a07d4350fd802f
--- /dev/null
+++ b/models/resnet50.py
@@ -0,0 +1,293 @@
+import torch
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
\ No newline at end of file
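In `ModelInterface`, `self.model_ft` maps each patch to a feature vector before the MIL aggregator sees the bag. One plausible way to use the vendored `resnet50` above in that role is to drop its 1000-way classification head; this wiring is illustrative, not code taken from the repo:

# Hypothetical wiring: turn the resnet50 defined above into a patch-level
# feature extractor by replacing its classifier with an identity mapping.
import torch
import torch.nn as nn
from models.resnet50 import resnet50  # module added by this diff

backbone = resnet50(pretrained=False)  # True would pull ImageNet weights
backbone.fc = nn.Identity()            # emit 2048-d features, not logits
backbone.eval()

with torch.no_grad():
    patches = torch.randn(8, 3, 256, 256)  # a bag of 8 patches
    feats = backbone(patches)               # (8, 2048)
    bag = feats.unsqueeze(0)                 # (1, 8, 2048) for the MIL head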
diff --git a/train.py b/train.py
index 182e3c0f4f0df1096166441731e39794f8599766..036d5ed183d4c8ee047ff3413bfb3d4730154b19 100644
--- a/train.py
+++ b/train.py
@@ -3,6 +3,8 @@ from pathlib import Path
 import numpy as np
 import glob
 
+from sklearn.model_selection import KFold
+
 from datasets.data_interface import DataInterface, MILDataModule
 from models.model_interface import ModelInterface
 import models.vision_transformer as vits
@@ -11,25 +13,32 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
+from utils.utils import KFoldLoop
+import torch
 
 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
-    # parser.add_argument('--gpus', default = [2])
+    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
+    parser.add_argument('--bag_size', default = 1024, type=int)
     args = parser.parse_args()
     return args
 
 #---->main
 def main(cfg):
 
+    torch.set_num_threads(16)
+
     #---->Initialize seed
     pl.seed_everything(cfg.General.seed)
 
     #---->load loggers
     cfg.load_loggers = load_loggers(cfg)
+    # print(cfg.load_loggers)
 
     #---->load callbacks
     cfg.callbacks = load_callbacks(cfg)
@@ -49,7 +58,10 @@ def main(cfg):
                 'batch_size': cfg.Data.train_dataloader.batch_size,
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
+                'backbone': cfg.Model.backbone,
+                'bag_size': cfg.Data.bag_size,
                 }
+
     dm = MILDataModule(**DataInterface_dict)
 
@@ -70,9 +82,9 @@ def main(cfg):
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        gpus=cfg.General.gpus,
-        # gpus = [4],
-        # strategy='ddp',
+        # gpus=cfg.General.gpus,
+        gpus = [2,3],
+        strategy='ddp',
         amp_backend='native',
         # amp_level=cfg.General.amp_level,
         precision=cfg.General.precision,
@@ -85,6 +97,9 @@ def main(cfg):
 
     #---->train or test
     if cfg.General.server == 'train':
+        internal_fit_loop = trainer.fit_loop
+        trainer.fit_loop = KFoldLoop(3, export_path=cfg.log_path)
+        trainer.fit_loop.connect(internal_fit_loop)
         trainer.fit(model = model, datamodule = dm)
     else:
         model_paths = list(cfg.log_path.glob('*.ckpt'))
@@ -94,6 +109,7 @@ def main(cfg):
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
             trainer.test(model=new_model, datamodule=dm)
 
+
 if __name__ == '__main__':
 
     args = make_parse()
@@ -101,9 +117,12 @@ if __name__ == '__main__':
 
     #---->update
     cfg.config = args.config
-    # cfg.General.gpus = args.gpus
+    cfg.General.gpus = [args.gpus]
     cfg.General.server = args.stage
     cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
+
 
     #---->main
     main(cfg)
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
index e1eeb958497b23b1f9e9934764083531720a66d0..704aade8ba3e8bcbedf237b3fde5db80b84fcd32 100644
Binary files a/utils/__pycache__/utils.cpython-39.pyc and b/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/utils/extract_features.py b/utils/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb040a175099d6e6612a7634a10eee07c4345cde
--- /dev/null
+++ b/utils/extract_features.py
@@ -0,0 +1,29 @@
+## Choose Model and extract features from (augmented) image patches and save as .pt file
+
+import torch
+
+from datasets.custom_dataloader import HDF5MILDataloader
+
+
+def extract_features(input_dir, output_dir, model_name, batch_size, label_path, n_classes):
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dataset = HDF5MILDataloader(input_dir, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    if model_name == 'resnet50':
+        model = Resnet50_baseline(pretrained=True)  # assumed pretrained feature-extractor wrapper; import still TODO
+    model = model.to(device)
+    model.eval()
+
+
+
+if __name__ == '__main__':
+
+    # input_dir, output_dir
+    # initiate data loader
+    # use data loader to load and augment images
+    # prediction from model
+    # choose save as bag or not (needed?)
+
+    # features = torch.from_numpy(features)
+    # torch.save(features, output_path + '.pt')
+
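The `__main__` block in `utils/extract_features.py` only lists the intended steps as comments. A minimal sketch of that loop, assuming the dataloader yields `(patches, label, name)` bags the way the training pipeline does (`save_bag_features` and its arguments are hypothetical names, not repo API):

# Sketch of the extraction loop outlined in the comments above. Assumes
# `dataset` yields (patches, label, name) and `model` maps (N, 3, H, W)
# patches to (N, D) features; both are placeholders to be wired in.
import torch
from torch.utils.data import DataLoader

def save_bag_features(dataset, model, output_dir, device):
    loader = DataLoader(dataset, batch_size=1, num_workers=4)
    model = model.to(device).eval()
    with torch.no_grad():
        for patches, label, name in loader:
            feats = model(patches.squeeze(0).to(device))  # (N, D) per bag
            torch.save(feats.cpu(), f'{output_dir}/{name[0]}.pt')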
diff --git a/utils/utils.py b/utils/utils.py
index 96ed223d1f73fb9afbf951486376bc02c76ae75a..010eaab0c51e6fc784774502d85e7eef7303487e 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,4 +1,22 @@
 from pathlib import Path
+from abc import ABC, abstractclassmethod
+from copy import deepcopy
+from typing import Any, Dict
+import os.path as osp
+import torch
+import torchvision.transforms as T
+from sklearn.model_selection import KFold
+from torch.nn import functional as F
+from torch.utils.data import random_split
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Dataset, Subset
+from torchmetrics.classification.accuracy import Accuracy
+
+from pytorch_lightning import LightningDataModule, seed_everything, Trainer
+from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
 
 #---->read yaml
 import yaml
@@ -14,19 +32,32 @@ def load_loggers(cfg):
     log_path = cfg.General.log_path
     Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}'
+    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
     version_name = Path(cfg.config).name[:-5]
-    cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
-    print(f'---->Log dir: {cfg.log_path}')
-
+
     #---->TensorBoard
-    tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                             name = version_name, version = f'fold{cfg.Data.fold}',
-                                             log_graph = True, default_hp_metric = False)
-    #---->CSV
-    csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                      name = version_name, version = f'fold{cfg.Data.fold}', )
+    if cfg.stage != 'test':
+        cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
+        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
+                                                 name = version_name, version = f'fold{cfg.Data.fold}',
+                                                 log_graph = True, default_hp_metric = False)
+        #---->CSV
+        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
+                                          name = version_name, version = f'fold{cfg.Data.fold}', )
+    else:
+        cfg.log_path = Path(log_path) / log_name / version_name / f'test'
+        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
+                                                 name = version_name, version = f'test',
+                                                 log_graph = True, default_hp_metric = False)
+        #---->CSV
+        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
+                                          name = version_name, version = f'test', )
+
+    print(f'---->Log dir: {cfg.log_path}')
+
+    # return tb_logger
     return [tb_logger, csv_logger]
@@ -74,6 +105,14 @@ def load_callbacks(cfg):
                                          save_top_k = 1,
                                          mode = 'min',
                                          save_weights_only = True))
+        Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
+                                         dirpath = str(cfg.log_path),
+                                         filename = '{epoch:02d}-{val_auc:.4f}',
+                                         verbose = True,
+                                         save_last = True,
+                                         save_top_k = 1,
+                                         mode = 'max',
+                                         save_weights_only = True))
     return Mycallbacks
 
#---->val loss
@@ -84,3 +123,103 @@ def cross_entropy_torch(x, y):
     x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
     loss = - torch.sum(x_log) / y.shape[0]
     return loss
+
+#-----> convert labels for task
+label_map = {
+    'bin': {'0': 0, '1': 1, '2': 1, '3': 1, '4': 1, '5': None},
+    'tcmr_viral': {'0': None, '1': 0, '2': None, '3': None, '4': 1, '5': None},
+    'no_viral': {'0': 0, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'no_other': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': None},
+    'no_stable': {'0': None, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'all': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5},
+
+}
+
+def convert_labels_for_task(task, label):
+
+    return label_map[task][label]
+
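`label_map` collapses the six raw DeepGraft label codes into contiguous class indices per task, with `None` marking codes that are excluded from that task, and `convert_labels_for_task` is a thin lookup over it. Concretely, using values straight from the table above:

# 'tcmr_viral' keeps TCMR (raw code '1' -> class 0) and viral (raw '4' -> class 1);
# every other raw code maps to None and should be filtered out upstream.
raw = ['0', '1', '4', '5']
mapped = [convert_labels_for_task('tcmr_viral', code) for code in raw]
print(mapped)  # [None, 0, 1, None]
kept = [m for m in mapped if m is not None]
print(kept)    # [0, 1]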
+
+#-----> KFOLD LOOP
+
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, export_path: str) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to run fitting and testing on the current fold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
+        voting_model.trainer = self.trainer
+        # This requires to connect the new model and move it to the right device.
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # requires to be overridden as attributes of the wrapped loop are being accessed.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file
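`KFoldLoop` is adapted from the PyTorch Lightning k-fold loop example and replaces the trainer's default `FitLoop`; note that `BaseKFoldDataModule` and `EnsembleVotingModel` come from that same example and still need to be defined or imported alongside this class. A sketch of the intended wiring, using the Lightning 1.x-style API and a hypothetical export path:

# Sketch: swap the stock FitLoop for KFoldLoop (Lightning 1.x k-fold recipe).
# BaseKFoldDataModule / EnsembleVotingModel must exist for the loop to run.
from pytorch_lightning import Trainer
from utils.utils import KFoldLoop

trainer = Trainer(max_epochs=200, gpus=[0])
internal_fit_loop = trainer.fit_loop                # keep the default loop
trainer.fit_loop = KFoldLoop(num_folds=3, export_path='./checkpoints')
trainer.fit_loop.connect(internal_fit_loop)         # wrap it
# trainer.fit(model, datamodule=dm)                 # dm: a BaseKFoldDataModule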