diff --git a/.gitignore b/.gitignore index 4cf8dd15619e7c11d325ae0eb80bba874a99f06d..c9e4fdabe0660f6872b93a3d8a59aa1750a671e0 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -logs/* \ No newline at end of file +logs/* +lightning_logs/* +test/* \ No newline at end of file diff --git a/DeepGraft/AttMIL_resnet18_debug.yaml b/DeepGraft/AttMIL_resnet18_debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..03ebd7e60a5af9d32bf1442455e24fc9528557f8 --- /dev/null +++ b/DeepGraft/AttMIL_resnet18_debug.yaml @@ -0,0 +1,51 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: &epoch 1 + grad_acc: 2 + frozen_bn: False + patience: 2 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json' + fold: 0 + nfold: 2 + cross_val: False + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: AttMIL + n_classes: 2 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/AttMIL_simple_no_other.yaml b/DeepGraft/AttMIL_simple_no_other.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae90a80ee152a2bd9d5b04c059fff38b1e77c06f --- /dev/null +++ b/DeepGraft/AttMIL_simple_no_other.yaml @@ -0,0 +1,49 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 20 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 3 + cross_val: False + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: AttMIL + n_classes: 5 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/AttMIL_simple_no_viral.yaml b/DeepGraft/AttMIL_simple_no_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..37ee07479e014d9277f582196eb5a50f4d74e2de --- /dev/null +++ b/DeepGraft/AttMIL_simple_no_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 50 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: AttMIL + n_classes: 4 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/AttMIL_simple_tcmr_viral.yaml 
b/DeepGraft/AttMIL_simple_tcmr_viral.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c982d3ad1bae365a2497a599a805dadf9874a9c2 --- /dev/null +++ b/DeepGraft/AttMIL_simple_tcmr_viral.yaml @@ -0,0 +1,49 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 300 + grad_acc: 2 + frozen_bn: False + patience: 20 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + fold: 1 + nfold: 3 + cross_val: False + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: AttMIL + n_classes: 2 + backbone: simple + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml index 79d8ea88865ebddd6a918bdc4b9435b6ba973ff1..7687a0c971796c6b31d199961343047e612c91dd 100644 --- a/DeepGraft/TransMIL_efficientnet_no_other.yaml +++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml @@ -6,10 +6,10 @@ General: precision: 16 multi_gpu_mode: dp gpus: [0] - epochs: &epoch 1000 + epochs: &epoch 200 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 20 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml index 8780060ebd1475b273b06498800523d7108085b1..98fe3778c9528e38d027b1f86d4d4b18631b0fe8 100644 --- a/DeepGraft/TransMIL_efficientnet_no_viral.yaml +++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml @@ -5,11 +5,11 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [3] - epochs: &epoch 500 + gpus: [0, 2] + epochs: &epoch 200 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 20 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml index f69b5bfadaa3023c9b6113e59681daee73b29fe2..52230329255872d150dbc4d0552d1265dc3abc80 100644 --- a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml +++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml @@ -5,7 +5,7 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [3] + gpus: [0] epochs: &epoch 500 grad_acc: 2 frozen_bn: False @@ -16,10 +16,11 @@ General: Data: dataset_name: custom data_shuffle: False - data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + data_dir: '/home/ylan/data/DeepGraft/256_256um/' label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' fold: 1 - nfold: 4 + nfold: 3 + cross_val: True train_dataloader: batch_size: 1 @@ -33,6 +34,7 @@ Model: name: TransMIL n_classes: 2 backbone: efficientnet + in_features: 512 Optimizer: diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_resnet18_debug.yaml similarity index 91% rename from DeepGraft/TransMIL_debug.yaml rename to DeepGraft/TransMIL_resnet18_debug.yaml index d83ce0d902228644d4fdd8a3059cd4b135a69fde..29bfa1e34ef4250bedd587bb736a9309099b1c0c 100644 --- a/DeepGraft/TransMIL_debug.yaml +++ b/DeepGraft/TransMIL_resnet18_debug.yaml @@ -6,10 +6,10 @@ General: precision: 16 multi_gpu_mode: dp gpus: [1] - epochs: &epoch 200 + epochs: &epoch 1 grad_acc: 2 frozen_bn: False - patience: 200 + 
patience: 2 server: test #train #test log_path: logs/ @@ -19,7 +19,8 @@ Data: data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json' fold: 0 - nfold: 4 + nfold: 2 + cross_val: True train_dataloader: batch_size: 1 diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml index a756616cd4a128874731b17353d108856f72e9f4..f6e469763b0a6df74a2ce2306ae7e94f53d9770b 100644 --- a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml +++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml @@ -5,8 +5,8 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [3] - epochs: &epoch 500 + gpus: [0] + epochs: &epoch 200 grad_acc: 2 frozen_bn: False patience: 50 @@ -19,7 +19,8 @@ Data: data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json' fold: 1 - nfold: 4 + nfold: 3 + cross_val: True train_dataloader: batch_size: 1 diff --git a/__pycache__/train_loop.cpython-39.pyc b/__pycache__/train_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c5ca839312787cce6aed19f346ae40a7c04d83 Binary files /dev/null and b/__pycache__/train_loop.cpython-39.pyc differ diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc index 4a200bbb7d34328019d9bb9e604b0524127ae863..99e793aa4b84b2a5e9ee1882437346e6e4a33002 100644 Binary files a/datasets/__pycache__/custom_dataloader.cpython-39.pyc and b/datasets/__pycache__/custom_dataloader.cpython-39.pyc differ diff --git a/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20aefffd1a14afe3c499f646e9eb674810896281 Binary files /dev/null and b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc differ diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc index 9550db1509e8d477a1c15534a76c6f87976fbec4..59584d31b16ee5ba7d08220a2174fb5e79e9aaed 100644 Binary files a/datasets/__pycache__/data_interface.cpython-39.pyc and b/datasets/__pycache__/data_interface.cpython-39.pyc differ diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py index 02850f5db5e263a9d5cdbbfff923952dd5dbfa52..cf65534d749f1cdc764bbdd8193245043c7853f6 100644 --- a/datasets/custom_dataloader.py +++ b/datasets/custom_dataloader.py @@ -41,7 +41,7 @@ class HDF5MILDataloader(data.Dataset): data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). 
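The experiment configs above all share one schema (General / Data / Model / Optimizer / Loss) and are consumed through read_yaml(args.config) in train.py and test_visualize.py further down. A minimal sketch of such a loader, assuming it wraps PyYAML in an attribute-access dict via the addict package (the real helper lives in utils/utils.py and may differ):

import yaml
from addict import Dict

def read_yaml_sketch(fpath):
    # anchors such as "&epoch 300" resolve to plain scalars during parsing
    with open(fpath, 'r') as f:
        return Dict(yaml.safe_load(f))

cfg = read_yaml_sketch('DeepGraft/AttMIL_simple_tcmr_viral.yaml')
assert cfg.Model.name == 'AttMIL' and cfg.General.epochs == 300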
""" - def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024): + def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024): super().__init__() self.data_info = [] @@ -55,7 +55,6 @@ class HDF5MILDataloader(data.Dataset): self.label_path = label_path self.n_classes = n_classes self.bag_size = bag_size - self.backbone = backbone # self.label_file = label_path recursive = True @@ -134,10 +133,6 @@ class HDF5MILDataloader(data.Dataset): RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)), transforms.ToTensor() ]) - if self.backbone == 'dino': - self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16') - - # self._add_data_infos(load_data) def __getitem__(self, index): # get data @@ -150,9 +145,6 @@ class HDF5MILDataloader(data.Dataset): # print(img.shape) for img in batch: # expects numpy img = img.numpy().astype(np.uint8) - if self.backbone == 'dino': - img = self.feature_extractor(images=img, return_tensors='pt') - # img = self.resize_transforms(img) img = seq_img_d.augment_image(img) img = self.val_transforms(img) out_batch.append(img) @@ -160,23 +152,24 @@ class HDF5MILDataloader(data.Dataset): else: for img in batch: img = img.numpy().astype(np.uint8) - if self.backbone == 'dino': - img = self.feature_extractor(images=img, return_tensors='pt') - img = self.resize_transforms(img) - img = self.val_transforms(img) out_batch.append(img) - if len(out_batch) == 0: - # print(name) - out_batch = torch.randn(self.bag_size,3,256,256) - else: out_batch = torch.stack(out_batch) + # if len(out_batch) == 0: + # # print(name) + # out_batch = torch.randn(self.bag_size,3,256,256) + # else: + out_batch = torch.stack(out_batch) # print(out_batch.shape) # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch - + # print(out_batch.shape) + if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]): + print(name) + print(out_batch.shape) + out_batch = torch.permute(out_batch, (0, 2,1,3)) label = torch.as_tensor(label) label = torch.nn.functional.one_hot(label, num_classes=self.n_classes) - return out_batch, label, name + return out_batch, label, name #, name_batch def __len__(self): return len(self.data_info) @@ -184,7 +177,9 @@ class HDF5MILDataloader(data.Dataset): def _add_data_infos(self, file_path, load_data): wsi_name = Path(file_path).stem if wsi_name in self.slideLabelDict: + # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset label = self.slideLabelDict[wsi_name] + # print(wsi_name) idx = -1 self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx}) @@ -195,14 +190,29 @@ class HDF5MILDataloader(data.Dataset): """ with h5py.File(file_path, 'r') as h5_file: wsi_batch = [] + tile_names = [] for tile in h5_file.keys(): + img = h5_file[tile][:] img = img.astype(np.uint8) img = torch.from_numpy(img) # img = self.resize_transforms(img) + wsi_batch.append(img) - wsi_batch = torch.stack(wsi_batch) - wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size) + tile_names.append(tile) + + # print('Empty Container: ', file_path) #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5 + + if wsi_batch: + wsi_batch = torch.stack(wsi_batch) + else: + print('Empty Container: ', file_path) + wsi_batch = 
torch.randn(self.bag_size,3,256,256) + + if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]): + print(file_path) + print(wsi_batch.shape) + wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size) idx = self._add_to_cache(wsi_batch, file_path) file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) self.data_info[file_idx + idx]['cache_idx'] = idx @@ -461,16 +471,19 @@ class RandomHueSaturationValue(object): img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) return img #, lbl -def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512): +def to_fixed_size_bag(bag, bag_size: int = 512): # get up to bag_size elements bag_idxs = torch.randperm(bag.shape[0])[:bag_size] bag_samples = bag[bag_idxs] + # bag_sample_names = [bag_names[i] for i in bag_idxs] # zero-pad if we don't have enough samples zero_padded = torch.cat((bag_samples, torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3]))) + # zero_padded_names = bag_sample_names + ['']*(bag_size - len(bag_sample_names)) return zero_padded, min(bag_size, len(bag)) + # return zero_padded, zero_padded_names, min(bag_size, len(bag)) class RandomHueSaturationValue(object): @@ -510,11 +523,11 @@ if __name__ == '__main__': train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' - label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_no_other.json' output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments' os.makedirs(output_path, exist_ok=True) - n_classes = 2 + n_classes = 5 dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20) # print(dataset.dataset) @@ -528,15 +541,19 @@ if __name__ == '__main__': # print(len(dataset)) # # x = 0 + #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5 c = 0 label_count = [0] *n_classes for item in dl: - # if c >=10: - # break + if c >=10: + break bag, label, name = item - label_count[np.argmax(label)] += 1 - print(label_count) - print(len(train_ds)) + print(name) + # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG': + + # print(bag) + # print(label) + c += 1 # # # print(bag.shape) # # if bag.shape[1] == 1: # # print(name) @@ -578,7 +595,7 @@ if __name__ == '__main__': # o_img = Image.fromarray(o_img) # o_img = o_img.convert('RGB') # o_img.save(f'{output_path}/{i}_original.png') - # c += 1 + # break # else: break # print(data.shape) diff --git a/datasets/custom_jpg_dataloader.py b/datasets/custom_jpg_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..722b27593ee28ebbbcf42249cd9545bdde6ed85c --- /dev/null +++ b/datasets/custom_jpg_dataloader.py @@ -0,0 +1,459 @@ +''' +ToDo: remove bag_size +''' + + +import numpy as np +from pathlib import Path +import torch +from torch.utils import data +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm +import torchvision.transforms as transforms +from PIL import Image +import cv2 +import json +import albumentations as A +from imgaug import augmenters as iaa +import imgaug as ia +from torchsampler import ImbalancedDatasetSampler + + +class RangeNormalization(object): + def __call__(self, sample): + img = 
sample + return (img / 255.0 - 0.5) / 0.5 + +class JPGMILDataloader(data.Dataset): + """Represents a JPG tile dataset: one directory of pre-extracted tile images per whole-slide image. + + Input params: + file_path: Path to the root folder containing the cohort directories with tile images. + mode: 'train' or 'test' + load_data: If True, loads all the data immediately into RAM. Use this if + the dataset fits into memory. Otherwise, leave this at False and + the data will load lazily. + data_cache_size: Number of slides that can be held in the cache (default=10). + + """ + def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024): + super().__init__() + + self.data_info = [] + self.data_cache = {} + self.slideLabelDict = {} + self.files = [] + self.data_cache_size = data_cache_size + self.mode = mode + self.file_path = file_path + # self.csv_path = csv_path + self.label_path = label_path + self.n_classes = n_classes + self.bag_size = bag_size + self.empty_slides = [] + # self.label_file = label_path + recursive = True + + # read labels and slide_path from csv + with open(self.label_path, 'r') as f: + temp_slide_label_dict = json.load(f)[mode] + for (x, y) in temp_slide_label_dict: + x = Path(x).stem + + # x_complete_path = Path(self.file_path)/Path(x) + for cohort in Path(self.file_path).iterdir(): + x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x) + if x_complete_path.is_dir(): + if len(list(x_complete_path.iterdir())) > 50: + # print(x_complete_path) + self.slideLabelDict[x] = y + self.files.append(x_complete_path) + else: self.empty_slides.append(x_complete_path) + # print(len(self.empty_slides)) + # print(self.empty_slides) + + + for slide_dir in tqdm(self.files): + self._add_data_infos(str(slide_dir.resolve()), load_data) + + + self.resize_transforms = A.Compose([ + A.SmallestMaxSize(max_size=256) + ]) + sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1") + sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2") + sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3") + sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4") + sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5") + + self.train_transforms = iaa.Sequential([ + iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"), + sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")), + iaa.Fliplr(0.5, name="MyFlipLR"), + iaa.Flipud(0.5, name="MyFlipUD"), + sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")), + iaa.OneOf([ + sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")), + sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")), + sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine")) + ], name="MyOneOf") + + ], name="MyAug") + + # self.train_transforms = A.Compose([ + # A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0), + # # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5), + # # A.RandomGamma(), + # # A.HorizontalFlip(), + # # A.VerticalFlip(), + # # A.RandomRotate90(), + # # A.OneOf([ + # # A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50), + # # A.Affine( + # # scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)}, + # # rotate=(-45, 45), + # # shear=(-4, 4), + # # cval=8, + # # ) + # # ]), + # A.Normalize(), + # ToTensorV2(), + # ]) + self.val_transforms =
transforms.Compose([ + # A.Normalize(), + # ToTensorV2(), + RangeNormalization(), + transforms.ToTensor(), + + ]) + self.img_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(p=1), + transforms.RandomVerticalFlip(p=1), + # histoTransforms.AutoRandomRotation(), + transforms.Lambda(lambda a: np.array(a)), + ]) + self.hsv_transforms = transforms.Compose([ + RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)), + transforms.ToTensor() + ]) + + def __getitem__(self, index): + # get data + batch, label, name = self.get_data(index) + out_batch = [] + seq_img_d = self.train_transforms.to_deterministic() + + if self.mode == 'train': + # print(img) + # print(.shape) + for img in batch: # expects numpy + img = img.numpy().astype(np.uint8) + # print(img.shape) + img = seq_img_d.augment_image(img) + img = self.val_transforms(img) + out_batch.append(img) + + else: + for img in batch: + img = img.numpy().astype(np.uint8) + img = self.val_transforms(img) + out_batch.append(img) + + # if len(out_batch) == 0: + # # print(name) + # out_batch = torch.randn(self.bag_size,3,256,256) + # else: + out_batch = torch.stack(out_batch) + # print(out_batch.shape) + # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch + # print(out_batch.shape) + # if out_batch.shape != torch.Size([self.bag_size, 256, 256, 3]) and out_batch.shape != torch.Size([self.bag_size, 3,256,256]): + # print(name) + # print(out_batch.shape) + # out_batch = torch.permute(out_batch, (0, 2,1,3)) + label = torch.as_tensor(label) + label = torch.nn.functional.one_hot(label, num_classes=self.n_classes) + # print(out_batch) + return out_batch, label, name #, name_batch + + def __len__(self): + return len(self.data_info) + + def _add_data_infos(self, file_path, load_data): + wsi_name = Path(file_path).stem + if wsi_name in self.slideLabelDict: + # if wsi_name[:2] != 'RU': #skip RU because of container problems in dataset + label = self.slideLabelDict[wsi_name] + # print(wsi_name) + idx = -1 + self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx}) + + def _load_data(self, file_path): + """Load data to the cache given the file + path and update the cache index in the + data_info structure. 
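A note on the augmentation call in __getitem__ above: the pipeline is frozen once per bag with to_deterministic(), so every tile of a slide receives the identical spatial and color transform. A sketch of that behavior, assuming only that imgaug is installed (the pipeline below is shortened, not the full MyAug sequence):

import numpy as np
import imgaug.augmenters as iaa

seq = iaa.Sequential([iaa.AddToHueAndSaturation(value=(-13, 13)), iaa.Fliplr(0.5)])
seq_det = seq.to_deterministic()   # sample the random parameters once per bag
bag = [np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8) for _ in range(4)]
aug_bag = [seq_det.augment_image(img) for img in bag]   # same transform for every tile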
+ """ + wsi_batch = [] + tile_names = [] + # print(wsi_batch) + # for tile_path in Path(file_path).iterdir(): + # print(tile_path) + for tile_path in Path(file_path).iterdir(): + # print(tile_path) + img = np.asarray(Image.open(tile_path)).astype(np.uint8) + img = torch.from_numpy(img) + + # print(wsi_batch) + wsi_batch.append(img) + + tile_names.append(tile_path.stem) + + # if wsi_batch: + wsi_batch = torch.stack(wsi_batch) + if len(wsi_batch.shape) < 4: + wsi_batch.unsqueeze(0) + # else: + # print('Empty Container: ', file_path) + # self.empty_slides.append(file_path) + # wsi_batch = torch.randn(self.bag_size,256,256,3) + # print(wsi_batch.shape) + # if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]): + # print(file_path) + # print(wsi_batch.shape) + # wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size) + idx = self._add_to_cache(wsi_batch, file_path) + file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path) + self.data_info[file_idx + idx]['cache_idx'] = idx + + # remove an element from data cache if size was exceeded + if len(self.data_cache) > self.data_cache_size: + # remove one item from the cache at random + removal_keys = list(self.data_cache) + removal_keys.remove(file_path) + self.data_cache.pop(removal_keys[0]) + # remove invalid cache_idx + # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info] + + def _add_to_cache(self, data, data_path): + """Adds data to the cache and returns its index. There is one cache + list for every file_path, containing all datasets in that file. + """ + if data_path not in self.data_cache: + self.data_cache[data_path] = [data] + else: + self.data_cache[data_path].append(data) + return len(self.data_cache[data_path]) - 1 + + # def get_data_infos(self, type): + # """Get data infos belonging to a certain type of data. + # """ + # data_info_type = [di for di in self.data_info if di['type'] == type] + # return data_info_type + + def get_name(self, i): + # name = self.get_data_infos(type)[i]['name'] + name = self.data_info[i]['name'] + return name + + def get_labels(self, indices): + + return [self.data_info[i]['label'] for i in indices] + # return self.slideLabelDict.values() + + def get_data(self, i): + """Call this function anytime you want to access a chunk of data from the + dataset. This will make sure that the data is loaded in case it is + not part of the data cache. 
+ i = index + """ + # fp = self.get_data_infos(type)[i]['data_path'] + fp = self.data_info[i]['data_path'] + if fp not in self.data_cache: + self._load_data(fp) + + # get new cache_idx assigned by _load_data_info + # cache_idx = self.get_data_infos(type)[i]['cache_idx'] + cache_idx = self.data_info[i]['cache_idx'] + label = self.data_info[i]['label'] + name = self.data_info[i]['name'] + # print(self.data_cache[fp][cache_idx]) + return self.data_cache[fp][cache_idx], label, name + + + +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + +def to_fixed_size_bag(bag, bag_size: int = 512): + + #duplicate bag instances unitl + + # get up to bag_size elements + bag_idxs = torch.randperm(bag.shape[0])[:bag_size] + bag_samples = bag[bag_idxs] + # bag_sample_names = [bag_names[i] for i in bag_idxs] + q, r = divmod(bag_size, bag_samples.shape[0]) + if q > 0: + bag_samples = torch.cat([bag_samples]*q, 0) + + self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]]) + + # zero-pad if we don't have enough samples + # zero_padded = torch.cat((bag_samples, + # torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3]))) + + return self_padded, min(bag_size, len(bag)) + + +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + + + +if __name__ == '__main__': + from pathlib import Path + import os + + home = Path.cwd().parts[1] + train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' + data_root = f'/{home}/ylan/data/DeepGraft/256_256um' + # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + output_path = 
f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments' + os.makedirs(output_path, exist_ok=True) + + n_classes = 2 + + dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20) + # print(dataset.dataset) + # a = int(len(dataset)* 0.8) + # b = int(len(dataset) - a) + # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b]) + dl = DataLoader(dataset, None, num_workers=1) + print(len(dl)) + dl = DataLoader(dataset, None, sampler=ImbalancedDatasetSampler(dataset), num_workers=5) + + + + # data = DataLoader(dataset, batch_size=1) + + # print(len(dataset)) + # # x = 0 + #/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/RU0248_PASD_jke_PASD_20200201_195900_BIG.hdf5 + c = 0 + label_count = [0] *n_classes + print(len(dl)) + for item in dl: + # if c >=10: + # break + bag, label, name = item + # print(label) + label_count[torch.argmax(label)] += 1 + # print(name) + # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG': + + # print(bag) + # print(label) + c += 1 + print(label_count) + # # # print(bag.shape) + # # if bag.shape[1] == 1: + # # print(name) + # # print(bag.shape) + # print(bag.shape) + + # # out_dir = Path(output_path) / name + # # os.makedirs(out_dir, exist_ok=True) + + # # # print(item[2]) + # # # print(len(item)) + # # # print(item[1]) + # # # print(data.shape) + # # # data = data.squeeze() + # # bag = item[0] + # bag = bag.squeeze() + # original = original.squeeze() + # for i in range(bag.shape[0]): + # img = bag[i, :, :, :] + # img = img.squeeze() + + # img = ((img-img.min())/(img.max() - img.min())) * 255 + # print(img) + # # print(img) + # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + + # img = Image.fromarray(img) + # img = img.convert('RGB') + # img.save(f'{output_path}/{i}.png') + + + + # o_img = original[i,:,:,:] + # o_img = o_img.squeeze() + # print(o_img.shape) + # o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255 + # o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0) + # o_img = Image.fromarray(o_img) + # o_img = o_img.convert('RGB') + # o_img.save(f'{output_path}/{i}_original.png') + + # break + # else: break + # print(data.shape) + # print(label) + # a = [torch.Tensor((3,256,256))]*3 + # b = torch.stack(a) + # print(b) + # c = to_fixed_size_bag(b, 512) + # print(c) \ No newline at end of file diff --git a/datasets/data_interface.py b/datasets/data_interface.py index 056e6ff09bcf768c91e0f51e44640487ea4fbee1..efa104c4f3a1c0ce12ece5cfc285202c8f8f1825 100644 --- a/datasets/data_interface.py +++ b/datasets/data_interface.py @@ -2,15 +2,25 @@ import inspect # 查看python 类的参数和模块、函数代码 import importlib # In order to dynamically import the library from typing import Optional import pytorch_lightning as pl +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.fit_loop import FitLoop + from torch.utils.data import random_split, DataLoader +from torch.utils.data.dataset import Dataset, Subset from torchvision.datasets import MNIST from torchvision import transforms from .camel_dataloader import FeatureBagLoader from .custom_dataloader import HDF5MILDataloader +from .custom_jpg_dataloader import JPGMILDataloader from pathlib import Path from transformers import AutoFeatureExtractor from torchsampler import ImbalancedDatasetSampler +from abc import ABC, abstractclassmethod, abstractmethod +from sklearn.model_selection import KFold + + + class DataInterface(pl.LightningDataModule): def __init__(self, train_batch_size=64, train_num_workers=8, 
test_batch_size=1, test_num_workers=1,dataset_name=None, **kwargs): @@ -109,7 +119,7 @@ class DataInterface(pl.LightningDataModule): class MILDataModule(pl.LightningDataModule): - def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs): + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=50, n_classes=2, cache: bool=True, *args, **kwargs): super().__init__() self.data_root = data_root self.label_path = label_path @@ -124,27 +134,29 @@ class MILDataModule(pl.LightningDataModule): self.num_bags_test = 50 self.seed = 1 - self.backbone = backbone self.cache = True self.fe_transform = None + def setup(self, stage: Optional[str] = None) -> None: home = Path.cwd().parts[1] if stage in (None, 'fit'): - dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone) + dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) a = int(len(dataset)* 0.8) b = int(len(dataset) - a) self.train_data, self.valid_data = random_split(dataset, [a, b]) if stage in (None, 'test'): - self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone) + self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, data_cache_size=1) return super().setup(stage=stage) + + def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + return DataLoader(self.train_data, batch_size = self.batch_size, sampler=ImbalancedDatasetSampler(self.train_data), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, #sampler=ImbalancedDatasetSampler(self.train_data) def val_dataloader(self) -> DataLoader: return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers) @@ -187,13 +199,92 @@ class DataModule(pl.LightningDataModule): if stage in (None, 'test'): self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone) + return super().setup(stage=stage) def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, + return DataLoader(self.train_data, self.batch_size, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True, #sampler=ImbalancedDatasetSampler(self.train_data), def val_dataloader(self) -> DataLoader: - return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers) + return DataLoader(self.valid_data, batch_size = self.batch_size) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.test_data, batch_size = self.batch_size) #, num_workers=self.num_workers + + +class BaseKFoldDataModule(pl.LightningDataModule, ABC): + @abstractmethod + def setup_folds(self, num_folds: int) -> None: + pass + + @abstractmethod + def setup_fold_index(self, fold_index: int) -> None: + pass + +class CrossVal_MILDataModule(BaseKFoldDataModule): + + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, 
*args, **kwargs): + super().__init__() + self.data_root = data_root + self.label_path = label_path + self.batch_size = batch_size + self.num_workers = num_workers + self.image_size = 384 + self.n_classes = n_classes + self.target_number = 9 + self.mean_bag_length = 10 + self.var_bag_length = 2 + self.num_bags_train = 200 + self.num_bags_test = 50 + self.seed = 1 + + self.backbone = backbone + self.cache = True + self.fe_transform = None + + # train_dataset: Optional[Dataset] = None + # test_dataset: Optional[Dataset] = None + # train_fold: Optional[Dataset] = None + # val_fold: Optional[Dataset] = None + self.train_data : Optional[Dataset] = None + self.test_data : Optional[Dataset] = None + self.train_fold : Optional[Dataset] = None + self.val_fold : Optional[Dataset] = None + + def setup(self, stage: Optional[str] = None) -> None: + home = Path.cwd().parts[1] + + # if stage in (None, 'fit'): + dataset = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) + # a = int(len(dataset)* 0.8) + # b = int(len(dataset) - a) + # self.train_data, self.val_data = random_split(dataset, [a, b]) + self.train_data = dataset + + # if stage in (None, 'test'):, + self.test_data = JPGMILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes) + + # return super().setup(stage=stage) + + def setup_folds(self, num_folds: int) -> None: + self.num_folds = num_folds + self.splits = [split for split in KFold(num_folds).split(range(len(self.train_data)))] + + def setup_fold_index(self, fold_index: int) -> None: + train_indices, val_indices = self.splits[fold_index] + self.train_fold = Subset(self.train_data, train_indices) + self.val_fold = Subset(self.train_data, val_indices) + + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.train_fold, self.batch_size, sampler=ImbalancedDatasetSampler(self.train_fold), num_workers=self.num_workers) #batch_transforms=self.transform, pseudo_batch_dim=True, + # return DataLoader(self.train_fold, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + #sampler=ImbalancedDatasetSampler(self.train_data) + def val_dataloader(self) -> DataLoader: + return DataLoader(self.val_fold, batch_size = self.batch_size, num_workers=self.num_workers) def test_dataloader(self) -> DataLoader: - return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers) \ No newline at end of file + return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers) + + + diff --git a/models/AttMIL.py b/models/AttMIL.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e20eb3e5d465eab0008c5121b21f1fd9394605 --- /dev/null +++ b/models/AttMIL.py @@ -0,0 +1,79 @@ +import os +import logging +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +import pytorch_lightning as pl +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR + + +class AttMIL(nn.Module): #gated attention + def __init__(self, n_classes, features=512): + super(AttMIL, self).__init__() + self.L = features + self.D = 128 + self.K = 1 + self.n_classes = n_classes + + # resnet50 = models.resnet50(pretrained=True) + # modules = list(resnet50.children())[:-3] + + # self.resnet_extractor = nn.Sequential( + # *modules, + # nn.AdaptiveAvgPool2d(1), + # View((-1, 
1024)), + # nn.Linear(1024, self.L) + # ) + + # self.feature_extractor1 = nn.Sequential( + # nn.Conv2d(3, 20, kernel_size=5), + # nn.ReLU(), + # nn.MaxPool2d(2, stride=2), + # nn.Conv2d(20, 50, kernel_size=5), + # nn.ReLU(), + # nn.MaxPool2d(2, stride=2), + + # # View((-1, 50 * 4 * 4)), + # # nn.Linear(50 * 4 * 4, self.L), + # # nn.ReLU(), + # ) + + # self.feature_extractor_part2 = nn.Sequential( + # nn.Linear(50 * 4 * 4, self.L), + # nn.ReLU(), + # ) + + self.attention_V = nn.Sequential( + nn.Linear(self.L, self.D), + nn.Tanh() + ) + + self.attention_U = nn.Sequential( + nn.Linear(self.L, self.D), + nn.Sigmoid() + ) + + self.attention_weights = nn.Linear(self.D, self.K) + + self.classifier = nn.Sequential( + nn.Linear(self.L * self.K, self.n_classes), + ) + + def forward(self, x): + # H = kwargs['data'].float().squeeze(0) + H = x.float().squeeze(0) + A_V = self.attention_V(H) # NxD + A_U = self.attention_U(H) # NxD + A = self.attention_weights(A_V * A_U) # element wise multiplication # NxK + out_A = A + A = torch.transpose(A, 1, 0) # KxN + A = F.softmax(A, dim=1) # softmax over N + M = torch.mm(A, H) # KxL + logits = self.classifier(M) + + return logits \ No newline at end of file diff --git a/models/TransMIL.py b/models/TransMIL.py index 69089de0bd934ea38b2854f513295b21cbb73a03..ca2a1fbe8240fd1983b246cbbfbdad187c48b214 100755 --- a/models/TransMIL.py +++ b/models/TransMIL.py @@ -44,23 +44,23 @@ class PPEG(nn.Module): class TransMIL(nn.Module): - def __init__(self, n_classes): + def __init__(self, n_classes, in_features, out_features=384): super(TransMIL, self).__init__() - self.pos_layer = PPEG(dim=512) - self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()) + self.pos_layer = PPEG(dim=out_features) + self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU()) # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) - self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) + self.cls_token = nn.Parameter(torch.randn(1, 1, out_features)) self.n_classes = n_classes - self.layer1 = TransLayer(dim=512) - self.layer2 = TransLayer(dim=512) - self.norm = nn.LayerNorm(512) - self._fc2 = nn.Linear(512, self.n_classes) + self.layer1 = TransLayer(dim=out_features) + self.layer2 = TransLayer(dim=out_features) + self.norm = nn.LayerNorm(out_features) + self._fc2 = nn.Linear(out_features, self.n_classes) - def forward(self, **kwargs): #, **kwargs + def forward(self, x): #, **kwargs - h = kwargs['data'].float() #[B, n, 1024] - # h = self._fc1(h) #[B, n, 512] + h = x.float() #[B, n, 1024] + h = self._fc1(h) #[B, n, 512] #---->pad H = h.shape[1] @@ -83,22 +83,22 @@ class TransMIL(nn.Module): h = self.layer2(h) #[B, N, 512] #---->cls_token + print(h.shape) #[1, 1025, 512] 1025 = cls_token + 1024 + + # tokens = h h = self.norm(h)[:,0] #---->predict logits = self._fc2(h) #[B, n_classes] - Y_hat = torch.argmax(logits, dim=1) - Y_prob = F.softmax(logits, dim = 1) - results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} - return results_dict + return logits if __name__ == "__main__": data = torch.randn((1, 6000, 512)).cuda() - model = TransMIL(n_classes=2).cuda() + model = TransMIL(n_classes=2, in_features=512).cuda() print(model.eval()) - results_dict = model(data = data) + results_dict = model(data) print(results_dict) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] + # logits = results_dict['logits'] + # Y_prob = results_dict['Y_prob'] + # Y_hat = results_dict['Y_hat'] # print(F.sigmoid(logits)) diff --git 
a/models/__pycache__/AttMIL.cpython-39.pyc b/models/__pycache__/AttMIL.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ee3af4e0d05559e5bc7a4d8ed62481a3254c9f1 Binary files /dev/null and b/models/__pycache__/AttMIL.cpython-39.pyc differ diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc index cb0eb6fd5085feda9604ec0f24420d85eb1d939d..a329e23346d376b150c7ba792b8cd3593a128d95 100644 Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and b/models/__pycache__/TransMIL.cpython-39.pyc differ diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc index 0bf337675c76de4097e7f7723b99de0120f9594c..e9d22d7ddccbb2a3d2ab59b98e066f9ac5d42285 100644 Binary files a/models/__pycache__/model_interface.cpython-39.pyc and b/models/__pycache__/model_interface.cpython-39.pyc differ diff --git a/models/model_interface.py b/models/model_interface.py index 60b5cc73d6bec977216cdde6bb8bcee3fd267034..b3a561eba25656b93f4e4663c88e796022a556d0 100755 --- a/models/model_interface.py +++ b/models/model_interface.py @@ -7,6 +7,8 @@ import pandas as pd import seaborn as sns from pathlib import Path from matplotlib import pyplot as plt +import cv2 +from PIL import Image #----> from MyOptimizer import create_optimizer @@ -20,7 +22,10 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchmetrics +from torchmetrics.functional import stat_scores from torch import optim as optim +# from sklearn.metrics import roc_curve, auc, roc_curve_score + #----> import pytorch_lightning as pl @@ -29,6 +34,10 @@ from torchvision import models from torchvision.models import resnet from transformers import AutoFeatureExtractor, ViTModel +from pytorch_grad_cam import GradCAM, EigenGradCAM +from pytorch_grad_cam.utils.image import show_cam_on_image +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget + from captum.attr import LayerGradCam class ModelInterface(pl.LightningModule): @@ -41,11 +50,20 @@ class ModelInterface(pl.LightningModule): self.loss = create_loss(loss) # self.asl = AsymmetricLossSingleLabel() # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1) - # self.loss = + # print(self.model) + + + # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) self.optimizer = optimizer self.n_classes = model.n_classes - self.log_path = kargs['log'] + print(self.n_classes) + self.save_path = kargs['log'] + if Path(self.save_path).parts[3] == 'tcmr': + temp = list(Path(self.save_path).parts) + # print(temp) + temp[3] = 'tcmr_viral' + self.save_path = '/'.join(temp) #---->acc self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] @@ -53,6 +71,7 @@ class ModelInterface(pl.LightningModule): #---->Metrics if self.n_classes > 2: self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted') + metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes, average='micro'), torchmetrics.CohenKappa(num_classes = self.n_classes), @@ -67,6 +86,7 @@ class ModelInterface(pl.LightningModule): else : self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted') + metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2, average = 'micro'), torchmetrics.CohenKappa(num_classes = 2), @@ -76,6 +96,8 @@ class ModelInterface(pl.LightningModule): num_classes = 2), torchmetrics.Precision(average = 
'macro', num_classes = 2)]) + self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes) + # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10) self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes) self.valid_metrics = metrics.clone(prefix = 'val_') self.test_metrics = metrics.clone(prefix = 'test_') @@ -146,28 +168,40 @@ class ModelInterface(pl.LightningModule): nn.Linear(1024, self.out_features), nn.ReLU(), ) + # print(self.model_ft[0].features[-1]) + # print(self.model_ft) + if model.name == 'TransMIL': + target_layers = [self.model.layer2.norm] # 32x32 + # target_layers = [self.model_ft[0].features[-1]] # 32x32 + self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform + # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform + else: + target_layers = [self.model.attention_weights] + self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True) + + def forward(self, x): + + feats = self.model_ft(x).unsqueeze(0) + return self.model(feats) + + def step(self, input): + + input = input.squeeze(0).float() + logits = self(input) + + Y_hat = torch.argmax(logits, dim=1) + Y_prob = F.softmax(logits, dim=1) + + return logits, Y_prob, Y_hat def training_step(self, batch, batch_idx): #---->inference - data, label, _ = batch + + input, label, _= batch label = label.float() - data = data.squeeze(0).float() - # print(data) - # print(data.shape) - if self.backbone == 'dino': - features = self.model_ft(**data) - features = features.last_hidden_state - else: - features = self.model_ft(data) - features = features.unsqueeze(0) - # print(features.shape) - # features = features.squeeze() - results_dict = self.model(data=features) - # results_dict = self.model(data=data, label=label) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] + + logits, Y_prob, Y_hat = self.step(input) #---->loss loss = self.loss(logits, label) @@ -183,6 +217,14 @@ class ModelInterface(pl.LightningModule): # Y = int(label[0]) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (int(Y_hat) == Y) + self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True) + + if self.current_epoch % 10 == 0: + + grid = torchvision.utils.make_grid(images) + # log input images + # self.loggers[0].experiment.add_figure(f'{stage}/input', , self.current_epoch) + return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} @@ -212,18 +254,10 @@ class ModelInterface(pl.LightningModule): def validation_step(self, batch, batch_idx): - data, label, _ = batch - + input, label, _ = batch label = label.float() - data = data.squeeze(0).float() - features = self.model_ft(data) - features = features.unsqueeze(0) - - results_dict = self.model(data=features) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] - + + logits, Y_prob, Y_hat = self.step(input) #---->acc log # Y = int(label[0][1]) @@ -237,18 +271,23 @@ class ModelInterface(pl.LightningModule): def validation_epoch_end(self, val_step_outputs): logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0) - # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0) probs = torch.cat([x['Y_prob'] for x in val_step_outputs]) max_probs = torch.stack([x['Y_hat'] for x in 
val_step_outputs]) - # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0) target = torch.cat([x['label'] for x in val_step_outputs]) target = torch.argmax(target, dim=1) #----> # logits = logits.long() # target = target.squeeze().long() # logits = logits.squeeze(0) + if len(target.unique()) != 1: + self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) + else: + self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True) + + + self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True) - self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) + # print(max_probs.squeeze(0).shape) # print(target.shape) @@ -276,24 +315,61 @@ class ModelInterface(pl.LightningModule): random.seed(self.count*50) def test_step(self, batch, batch_idx): - - data, label, _ = batch + torch.set_grad_enabled(True) + data, label, name = batch label = label.float() + # logits, Y_prob, Y_hat = self.step(data) + # print(data.shape) data = data.squeeze(0).float() - features = self.model_ft(data) - features = features.unsqueeze(0) + logits = self(data).detach() + + Y = torch.argmax(label) + Y_hat = torch.argmax(logits, dim=1) + Y_prob = F.softmax(logits, dim = 1) + + #----> Get Topk tiles + + target = [ClassifierOutputTarget(Y)] + + data_ft = self.model_ft(data).unsqueeze(0).float() + # data_ft = self.model_ft(data).unsqueeze(0).float() + # print(data_ft.shape) + # print(target) + grayscale_cam = self.cam(input_tensor=data_ft, targets=target) + # grayscale_ecam = self.ecam(input_tensor=data_ft, targets=target) + + # print(grayscale_cam) - results_dict = self.model(data=features, label=label) - logits = results_dict['logits'] - Y_prob = results_dict['Y_prob'] - Y_hat = results_dict['Y_hat'] + summed = torch.mean(torch.Tensor(grayscale_cam), dim=2) + print(summed) + print(summed.shape) + topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0) + topk_data = data[topk_indices].detach() + + # target_ft = + # grayscale_cam_ft = self.cam_ft(input_tensor=data, ) + # for i in range(data.shape[0]): + + # vis_img = data[i, :, :, :].cpu().numpy() + # vis_img = np.transpose(vis_img, (1,2,0)) + # print(vis_img.shape) + # cam_img = grayscale_cam.squeeze(0) + # cam_img = self.reshape_transform(grayscale_cam) + + # print(cam_img.shape) + + # visualization = show_cam_on_image(vis_img, cam_img, use_rgb=True) + # visualization = ((visualization/visualization.max())*255.0).astype(np.uint8) + # print(visualization) + # cv2.imwrite(f'{test_path}/{Y}/{name}/gradcam.jpg', cam_img) #---->acc log Y = torch.argmax(label) self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name, 'topk_data': topk_data} # + # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data def test_epoch_end(self, output_results): probs = torch.cat([x['Y_prob'] for x in output_results]) @@ -301,7 +377,8 @@ class ModelInterface(pl.LightningModule): # target = torch.stack([x['label'] for x in output_results], dim = 0) target = torch.cat([x['label'] for x in output_results]) target = torch.argmax(target, dim=1) - + patients = [x['name'] for x in output_results] + topk_tiles = [x['topk_data'] for x in output_results] #----> auc = self.AUROC(probs, target.squeeze()) 
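# Sketch of the tile ranking performed in test_step above: the Grad-CAM output is
# averaged over its last axis into one relevance score per tile, then torch.topk
# keeps the five most attended tiles of the bag. The [1, n_tiles, d] CAM shape is
# read off the code and is an assumption, not a verified contract.
def _topk_tiles_sketch(grayscale_cam, data, k=5):
    scores = torch.mean(torch.Tensor(grayscale_cam), dim=2)          # [1, n_tiles]
    topk_scores, topk_idx = torch.topk(scores.squeeze(0), k, dim=0)
    return data[topk_idx]                                            # k most attended tiles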
metrics = self.test_metrics(max_probs.squeeze() , target) @@ -312,9 +389,41 @@ class ModelInterface(pl.LightningModule): # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True) - # print(max_probs.squeeze(0).shape) - # print(target.shape) - # self.log_dict(metrics, logger = True) + #---->get highest scoring patients for each class + test_path = Path(self.save_path) / 'most_predictive' + topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0) + for n in range(self.n_classes): + print('class: ', n) + topk_patients = [patients[i[n]] for i in topk_indices] + topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices] + for x, p, t in zip(topk, topk_patients, topk_patient_tiles): + print(p, x[n]) + patient = p[0] + outpath = test_path / str(n) / patient + outpath.mkdir(parents=True, exist_ok=True) + for i in range(len(t)): + tile = t[i] + tile = tile.cpu().numpy().transpose(1,2,0) + tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255 + tile = tile.astype(np.uint8) + img = Image.fromarray(tile) + + img.save(f'{test_path}/{n}/{patient}/{i}_gradcam.jpg') + + + + #----->visualize top predictive tiles + + + + + # img = img.squeeze(0).cpu().numpy() + # img = np.transpose(img, (1,2,0)) + # # print(img) + # # print(grayscale_cam.shape) + # visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True) + + for keys, values in metrics.items(): print(f'{keys} = {values}') metrics[keys] = values.cpu().numpy() @@ -329,16 +438,35 @@ class ModelInterface(pl.LightningModule): print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count)) self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] + #---->plot auroc curve + # stats = stat_scores(probs, target, reduce='macro', num_classes=self.n_classes) + # fpr = {} + # tpr = {} + # for n in self.n_classes: + + # fpr, tpr, thresh = roc_curve(target.cpu().numpy(), probs.cpu().numpy()) + #[tp, fp, tn, fn, tp+fn] + + self.log_confusion_matrix(probs, target, stage='test') #----> result = pd.DataFrame([metrics]) - result.to_csv(self.log_path / 'result.csv') + result.to_csv(Path(self.save_path) / f'test_result.csv', mode='a', header=not Path(self.save_path).exists()) + + # with open(f'{self.save_path}/test_metrics.txt', 'a') as f: + + # f.write([metrics]) def configure_optimizers(self): # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1) optimizer = create_optimizer(self.optimizer, self.model) return optimizer + def reshape_transform(self, tensor, h=32, w=32): + result = tensor[:, 1:, :].reshape(tensor.size(0), h, w, tensor.size(2)) + result = result.transpose(2,3).transpose(1,2) + # print(result.shape) + return result def load_model(self): name = self.hparams.model.name @@ -372,18 +500,33 @@ class ModelInterface(pl.LightningModule): args1.update(other_args) return Model(**args1) + def log_image(self, tensor, stage, name): + + tile = tile.cpu().numpy().transpose(1,2,0) + tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255 + tile = tile.astype(np.uint8) + img = Image.fromarray(tile) + self.loggers[0].experiment.add_figure(f'{stage}/{name}', img, self.current_epoch) + + def log_confusion_matrix(self, max_probs, target, stage): confmat = self.confusion_matrix(max_probs.squeeze(), target) + print(confmat) df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) - plt.figure() + # plt.figure() fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() # plt.close(fig_) - # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}') - 
diff --git a/test_visualize.py b/test_visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ef56d3fec1a0c3d926428e95b051c13bdb1da96
--- /dev/null
+++ b/test_visualize.py
@@ -0,0 +1,148 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import glob
+
+from sklearn.model_selection import KFold
+
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
+from models.model_interface import ModelInterface
+import models.vision_transformer as vits
+from utils.utils import *
+
+# pytorch_lightning
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+import torch
+from train_loop import KFoldLoop
+
+#--->Setting parameters
+def make_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--stage', default='train', type=str)
+    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml', type=str)
+    parser.add_argument('--version', default=0, type=int)
+    parser.add_argument('--epoch', default='0', type=str)
+    parser.add_argument('--gpus', default=2, type=int)
+    parser.add_argument('--loss', default='CrossEntropyLoss', type=str)
+    parser.add_argument('--fold', default=0, type=int)
+    parser.add_argument('--bag_size', default=1024, type=int)
+
+    args = parser.parse_args()
+    return args
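Given the flags defined in `make_parse()`, a typical test run of this script might look like the following (config name and GPU index are illustrative, not prescribed by the repo):

    python test_visualize.py --config DeepGraft/TransMIL_resnet18_debug.yaml --stage test --epoch last --gpus 0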
+#---->main
+def main(cfg):
+
+    torch.set_num_threads(16)
+
+    #---->Initialize seed
+    pl.seed_everything(cfg.General.seed)
+
+    #---->load loggers
+    # cfg.load_loggers = load_loggers(cfg)
+    # print(cfg.load_loggers)
+    # save_path = Path(cfg.load_loggers[0].log_dir)
+
+    #---->load callbacks
+    # cfg.callbacks = load_callbacks(cfg, save_path)
+
+    home = Path.cwd().parts[1]
+    DataInterface_dict = {
+        'data_root': cfg.Data.data_dir,
+        'label_path': cfg.Data.label_file,
+        'batch_size': cfg.Data.train_dataloader.batch_size,
+        'num_workers': cfg.Data.train_dataloader.num_workers,
+        'n_classes': cfg.Model.n_classes,
+        'backbone': cfg.Model.backbone,
+        'bag_size': cfg.Data.bag_size,
+    }
+
+    dm = MILDataModule(**DataInterface_dict)
+
+    #---->Define Model
+    ModelInterface_dict = {'model': cfg.Model,
+                           'loss': cfg.Loss,
+                           'optimizer': cfg.Optimizer,
+                           'data': cfg.Data,
+                           'log': cfg.log_path,
+                           'backbone': cfg.Model.backbone,
+                           }
+    model = ModelInterface(**ModelInterface_dict)
+
+    #---->Instantiate Trainer
+    trainer = Trainer(
+        num_sanity_val_steps=0,
+        # logger=cfg.load_loggers,
+        # callbacks=cfg.callbacks,
+        max_epochs=cfg.General.epochs,
+        min_epochs=200,
+        gpus=cfg.General.gpus,
+        # gpus=[0, 2],
+        # strategy='ddp',
+        amp_backend='native',
+        # amp_level=cfg.General.amp_level,
+        precision=cfg.General.precision,
+        accumulate_grad_batches=cfg.General.grad_acc,
+        # fast_dev_run=True,
+        # deterministic=True,
+        check_val_every_n_epoch=10,
+    )
+
+    #---->train or test
+    log_path = cfg.log_path
+    # print(log_path)
+    # log_path = Path('lightning_logs/2/checkpoints')
+    model_paths = list(log_path.glob('*.ckpt'))
+
+    if cfg.epoch == 'last':
+        model_paths = [str(model_path) for model_path in model_paths if 'last' in str(model_path)]
+    else:
+        model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+
+    if not model_paths:
+        print('No checkpoints available!')
+    for path in model_paths:
+        # with open(f'{log_path}/test_metrics.txt', 'w') as f:
+        #     f.write(str(path) + '\n')
+        print(path)
+        new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
+        trainer.test(model=new_model, datamodule=dm)
+
+    # Top 5 scoring patches for patient
+    # GradCam
+
+
+if __name__ == '__main__':
+
+    args = make_parse()
+    cfg = read_yaml(args.config)
+
+    #---->update
+    cfg.config = args.config
+    cfg.General.gpus = [args.gpus]
+    cfg.General.server = args.stage
+    cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
+    cfg.version = args.version
+    cfg.epoch = args.epoch
+
+    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+    log_name = f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' / 'checkpoints'
+
+    #---->main
+    main(cfg)
\ No newline at end of file
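The `__main__` block above derives the checkpoint directory purely from the config filename. As a worked example of that string handling (config and loss names taken from this diff; the version number is assumed):

from pathlib import Path

config = 'DeepGraft/TransMIL_efficientnet_no_other.yaml'
task = '_'.join(Path(config).name[:-5].split('_')[2:])   # -> 'no_other'
log_name = '_efficientnet' + '_CrossEntropyLoss'
ckpt_dir = Path('logs') / Path(config).parent / 'TransMIL' / task / log_name / 'lightning_logs' / 'version_0' / 'checkpoints'
print(ckpt_dir)
# logs/DeepGraft/TransMIL/no_other/_efficientnet_CrossEntropyLoss/lightning_logs/version_0/checkpoints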
diff --git a/train.py b/train.py
index 036d5ed183d4c8ee047ff3413bfb3d4730154b19..5e30394352f771017d43e31cc8585ed4ef9aeb07 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ import glob

 from sklearn.model_selection import KFold

-from datasets.data_interface import DataInterface, MILDataModule
+from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
 from models.model_interface import ModelInterface
 import models.vision_transformer as vits
 from utils.utils import *
@@ -13,18 +13,23 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
-from pytorch_lightning.loops import KFoldLoop
 import torch
+from train_loop import KFoldLoop

 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    parser.add_argument('--version', default=2, type=int)
     parser.add_argument('--gpus', default = 2, type=int)
     parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
     parser.add_argument('--bag_size', default = 1024, type=int)
+    parser.add_argument('--resume_training', action='store_true')
+    # parser.add_argument('--ckpt_path', default = , type=str)
+
     args = parser.parse_args()
     return args
@@ -39,9 +44,10 @@ def main(cfg):

     #---->load loggers
     cfg.load_loggers = load_loggers(cfg)
     # print(cfg.load_loggers)
+    save_path = Path(cfg.load_loggers[0].log_dir)

     #---->load callbacks
-    cfg.callbacks = load_callbacks(cfg)
+    cfg.callbacks = load_callbacks(cfg, save_path)

     #---->Define Data
     # DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
@@ -58,11 +64,12 @@ def main(cfg):
         'batch_size': cfg.Data.train_dataloader.batch_size,
         'num_workers': cfg.Data.train_dataloader.num_workers,
         'n_classes': cfg.Model.n_classes,
-        'backbone': cfg.Model.backbone,
         'bag_size': cfg.Data.bag_size,
     }

-    dm = MILDataModule(**DataInterface_dict)
+    if cfg.Data.cross_val:
+        dm = CrossVal_MILDataModule(**DataInterface_dict)
+    else:
+        dm = MILDataModule(**DataInterface_dict)

     #---->Define Model
@@ -82,9 +89,9 @@ def main(cfg):
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        # gpus=cfg.General.gpus,
-        gpus = [2,3],
-        strategy='ddp',
+        gpus=cfg.General.gpus,
+        # gpus = [0,2],
+        # strategy='ddp',
         amp_backend='native',
         # amp_level=cfg.General.amp_level,
         precision=cfg.General.precision,
@@ -96,12 +103,31 @@ def main(cfg):
     )

     #---->train or test
+    if cfg.resume_training:
+        # resume from the last checkpoint of this version instead of starting a fresh fit
+        last_ckpt = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}' / 'last.ckpt'
+        trainer.fit(model = model, datamodule = dm, ckpt_path=last_ckpt)
+
+    elif cfg.General.server == 'train':
-        trainer.fit_loop = KFoldLoop(3, trainer.fit_loop, )
-        trainer.fit(model = model, datamodule = dm)
+
+        # k-fold cross validation loop
+        if cfg.Data.cross_val:
+            internal_fit_loop = trainer.fit_loop
+            trainer.fit_loop = KFoldLoop(cfg.Data.nfold, export_path = cfg.log_path, **ModelInterface_dict)
+            trainer.fit_loop.connect(internal_fit_loop)
+            trainer.fit(model, dm)
+        else:
+            trainer.fit(model = model, datamodule = dm)
     else:
-        model_paths = list(cfg.log_path.glob('*.ckpt'))
+        log_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.version}'
+
+        test_path = Path(log_path) / 'test'
+        for n in range(cfg.Model.n_classes):
+            n_output_path = test_path / str(n)
+            n_output_path.mkdir(parents=True, exist_ok=True)
+        # print(cfg.log_path)
+        model_paths = list(log_path.glob('*.ckpt'))
         model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
+        # model_paths = [f'{log_path}/epoch=279-val_loss=0.4009.ckpt']
         for path in model_paths:
             print(path)
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
@@ -120,6 +146,16 @@ if __name__ == '__main__':
     cfg.Data.fold = args.fold
     cfg.Loss.base_loss = args.loss
     cfg.Data.bag_size = args.bag_size
+    cfg.version = args.version
+
+    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+    log_name = f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name

     #---->main
diff --git a/train_loop.py b/train_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9236814589b30d30aab061aaf52599db5a130df
--- /dev/null
+++ b/train_loop.py
@@ -0,0 +1,212 @@
+from pytorch_lightning import LightningModule
+import torch
+import torch.nn.functional as F
+from torchmetrics.classification.accuracy import Accuracy
+import os.path as osp
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
+from datasets.data_interface import BaseKFoldDataModule
+from typing import Any, Dict, List, Optional, Type
+import torchmetrics
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+class EnsembleVotingModel(LightningModule):
+    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str], n_classes, log_path) -> None:
+        super().__init__()
+        # Create `num_folds` models with their associated fold weights
+        self.n_classes = n_classes
+        self.log_path = log_path
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
+        self.test_acc = Accuracy()
+        if self.n_classes > 2:
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes, average = 'micro'),
+                                                     torchmetrics.CohenKappa(num_classes = self.n_classes),
+                                                     torchmetrics.F1Score(num_classes = self.n_classes, average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro', num_classes = self.n_classes),
+                                                     torchmetrics.Precision(average = 'macro', num_classes = self.n_classes),
+                                                     torchmetrics.Specificity(average = 'macro', num_classes = self.n_classes)])
+        else:
+            self.AUROC = torchmetrics.AUROC(num_classes = 2, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2, average = 'micro'),
+                                                     torchmetrics.CohenKappa(num_classes = 2),
+                                                     torchmetrics.F1Score(num_classes = 2, average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro', num_classes = 2),
+                                                     torchmetrics.Precision(average = 'macro', num_classes = 2)])
+        self.test_metrics = metrics.clone(prefix = 'test_')
+        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
+
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        # Compute the averaged predictions over the `num_folds` models.
+        # print(batch[0].shape)
+        input, label, _ = batch
+        label = label.float()
+        input = input.squeeze(0).float()
+
+        logits = torch.stack([m(input) for m in self.models]).mean(0)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+
+        #---->acc log
+        Y = torch.argmax(label)
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+
+    def test_epoch_end(self, output_results):
+        probs = torch.cat([x['Y_prob'] for x in output_results])
+        max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.cat([x['label'] for x in output_results])
+        target = torch.argmax(target, dim=1)
+
+        #---->
+        auc = self.AUROC(probs, target.squeeze())
+        metrics = self.test_metrics(max_probs.squeeze(), target)
+        # metrics = self.test_metrics(max_probs.squeeze(), torch.argmax(target.squeeze(), dim=1))
+        metrics['test_auc'] = auc
+        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+
+        for keys, values in metrics.items():
+            print(f'{keys} = {values}')
+            metrics[keys] = values.cpu().numpy()
+
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0:
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+        self.log_confusion_matrix(probs, target, stage='test')
+        #---->
+        result = pd.DataFrame([metrics])
+        result.to_csv(self.log_path / 'result.csv')
+
+    def log_confusion_matrix(self, max_probs, target, stage):
+        confmat = self.confusion_matrix(max_probs.squeeze(), target)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        plt.figure()
+        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+        # plt.close(fig_)
+        # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
+        self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+        if stage == 'test':
+            plt.savefig(f'{self.log_path}/cm_test')
+            plt.close(fig_)
+
+
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, export_path: str, **kargs) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+        self.n_classes = kargs["model"].n_classes
+        self.log_path = kargs["log"]
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to run fitting and testing on the current fold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths, n_classes=self.n_classes, log_path=self.log_path)
+        voting_model.trainer = self.trainer
+        # This requires to connect the new model and move it to the right device.
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # requires to be overridden as attributes of the wrapped loop are being accessed.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
index 704aade8ba3e8bcbedf237b3fde5db80b84fcd32..33d3d005a578948dd3d5b58938a815f86c4e92f5 100644
Binary files a/utils/__pycache__/utils.cpython-39.pyc and b/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/utils/utils.py b/utils/utils.py
index 010eaab0c51e6fc784774502d85e7eef7303487e..814cf1df0d1258fa19663208e91cd23ece06f6eb 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -10,10 +10,11 @@ from torch.utils.data.dataset import Dataset, Subset
 from torchmetrics.classification.accuracy import Accuracy

 from pytorch_lightning import LightningDataModule, seed_everything, Trainer
-from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning import LightningModule
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.trainer.states import TrainerFn
+from typing import Any, Dict, List, Optional, Type

 #---->read yaml
 import yaml
@@ -27,30 +28,30 @@ def read_yaml(fpath=None):
 from pytorch_lightning import loggers as pl_loggers
 def load_loggers(cfg):

-    log_path = cfg.General.log_path
-    Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
-    version_name = Path(cfg.config).name[:-5]
+    # log_path = cfg.General.log_path
+    # Path(log_path).mkdir(exist_ok=True, parents=True)
+    # log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+    # version_name = Path(cfg.config).name[:-5]

     #---->TensorBoard
     if cfg.stage != 'test':
-        cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
-        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                                 name = version_name, version = f'fold{cfg.Data.fold}',
+        tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
+                                                 # version = f'fold{cfg.Data.fold}'
                                                  log_graph = True, default_hp_metric = False)
         #---->CSV
-        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                          name = version_name, version = f'fold{cfg.Data.fold}', )
+        csv_logger = pl_loggers.CSVLogger(cfg.log_path)
+        # version = f'fold{cfg.Data.fold}',
     else:
-        cfg.log_path = Path(log_path) / log_name / version_name / f'test'
-        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                                 name = version_name, version = f'test',
+        cfg.log_path = Path(cfg.log_path) / 'test'
+        tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
+                                                 version = 'test',
                                                  log_graph = True,
                                                 default_hp_metric = False)
         #---->CSV
-        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                          name = version_name, version = f'test', )
-
+        csv_logger = pl_loggers.CSVLogger(cfg.log_path,
+                                          version = 'test')
+
     print(f'---->Log dir: {cfg.log_path}')
@@ -63,11 +64,11 @@ from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping

-def load_callbacks(cfg):
+def load_callbacks(cfg, save_path):

     Mycallbacks = []
     # Make output path
-    output_path = cfg.log_path
+    output_path = save_path / 'checkpoints'
     output_path.mkdir(exist_ok=True, parents=True)

     early_stop_callback = EarlyStopping(
@@ -94,8 +95,9 @@ def load_callbacks(cfg):
         Mycallbacks.append(progress_bar)

     if cfg.General.server == 'train' :
+        # save_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.resume_version}' / last.ckpt
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
-                                           dirpath = str(cfg.log_path),
+                                           dirpath = str(output_path),
                                            filename = '{epoch:02d}-{val_loss:.4f}',
                                            verbose = True,
                                            save_last = True,
@@ -103,7 +105,7 @@ def load_callbacks(cfg):
                                            mode = 'min',
                                            save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
-                                           dirpath = str(cfg.log_path),
+                                           dirpath = str(output_path),
                                            filename = '{epoch:02d}-{val_auc:.4f}',
                                            verbose = True,
                                            save_last = True,
@@ -136,87 +138,3 @@ def convert_labels_for_task(task, label):
     return label_map[task][label]

-
-#-----> KFOLD LOOP
-
-class KFoldLoop(Loop):
-    def __init__(self, num_folds: int, export_path: str) -> None:
-        super().__init__()
-        self.num_folds = num_folds
-        self.current_fold: int = 0
-        self.export_path = export_path
-
-    @property
-    def done(self) -> bool:
-        return self.current_fold >= self.num_folds
-
-    def connect(self, fit_loop: FitLoop) -> None:
-        self.fit_loop = fit_loop
-
-    def reset(self) -> None:
-        """Nothing to reset in this loop."""
-
-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
-        model."""
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_folds(self.num_folds)
-        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
-
-    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
-        print(f"STARTING FOLD {self.current_fold}")
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_fold_index(self.current_fold)
-
-    def advance(self, *args: Any, **kwargs: Any) -> None:
-        """Used to the run a fitting and testing on the current hold."""
-        self._reset_fitting()  # requires to reset the tracking stage.
-        self.fit_loop.run()
-
-        self._reset_testing()  # requires to reset the tracking stage.
-        self.trainer.test_loop.run()
-        self.current_fold += 1  # increment fold tracking number.
-
-    def on_advance_end(self) -> None:
-        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
-        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
-        # restore the original weights + optimizers and schedulers.
-        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
-        self.trainer.strategy.setup_optimizers(self.trainer)
-        self.replace(fit_loop=FitLoop)
-
-    def on_run_end(self) -> None:
-        """Used to compute the performance of the ensemble model on the test set."""
-        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
-        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
-        voting_model.trainer = self.trainer
-        # This requires to connect the new model and move it the right device.
-        self.trainer.strategy.connect(voting_model)
-        self.trainer.strategy.model_to_device()
-        self.trainer.test_loop.run()
-
-    def on_save_checkpoint(self) -> Dict[str, int]:
-        return {"current_fold": self.current_fold}
-
-    def on_load_checkpoint(self, state_dict: Dict) -> None:
-        self.current_fold = state_dict["current_fold"]
-
-    def _reset_fitting(self) -> None:
-        self.trainer.reset_train_dataloader()
-        self.trainer.reset_val_dataloader()
-        self.trainer.state.fn = TrainerFn.FITTING
-        self.trainer.training = True
-
-    def _reset_testing(self) -> None:
-        self.trainer.reset_test_dataloader()
-        self.trainer.state.fn = TrainerFn.TESTING
-        self.trainer.testing = True
-
-    def __getattr__(self, key) -> Any:
-        # requires to be overridden as attributes of the wrapped loop are being accessed.
-        if key not in self.__dict__:
-            return getattr(self.fit_loop, key)
-        return self.__dict__[key]
-
-    def __setstate__(self, state: Dict[str, Any]) -> None:
-        self.__dict__.update(state)
\ No newline at end of file
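`KFoldLoop` only runs against a datamodule that satisfies the `BaseKFoldDataModule` contract asserted in `on_run_start`. That class lives in `datasets/data_interface.py` and is not part of this diff; the following is a minimal sketch of the contract in the style of the PyTorch Lightning k-fold example this loop follows. The dataset wiring, class name, and batch size are assumptions for illustration, not the repo's actual `CrossVal_MILDataModule`:

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

from pytorch_lightning import LightningDataModule
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset, Subset


class BaseKFoldDataModule(LightningDataModule, ABC):
    @abstractmethod
    def setup_folds(self, num_folds: int) -> None:
        """Compute the index splits for every fold."""

    @abstractmethod
    def setup_fold_index(self, fold_index: int) -> None:
        """Select the train/val subsets of the current fold."""


class SketchKFoldDataModule(BaseKFoldDataModule):
    # `full_dataset` stands in for the MIL bag dataset used in this repo
    def __init__(self, full_dataset: Dataset, batch_size: int = 1) -> None:
        super().__init__()
        self.full_dataset = full_dataset
        self.batch_size = batch_size
        self.splits: Optional[List[Tuple]] = None

    def setup_folds(self, num_folds: int) -> None:
        # called once by KFoldLoop.on_run_start
        self.splits = list(KFold(num_folds, shuffle=True).split(range(len(self.full_dataset))))

    def setup_fold_index(self, fold_index: int) -> None:
        # called by KFoldLoop.on_advance_start before each fold
        train_indices, val_indices = self.splits[fold_index]
        self.train_fold = Subset(self.full_dataset, train_indices)
        self.val_fold = Subset(self.full_dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_fold, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_fold, batch_size=self.batch_size)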