Commit 54cf0ab3 authored by Ycblue
First Commit
parent 3f6bbe86
Showing with 1116 additions and 60 deletions
import pandas as pd
import numpy as np
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.nn.functional import one_hot
import torch.utils.data as data_utils
from torchvision import datasets, transforms
from sklearn.utils import shuffle
from pathlib import Path
from tqdm import tqdm
class FeatureBagLoader(data_utils.Dataset):
def __init__(self, data_root, train=True, cache=True):
bags_path = pd.read_csv(data_root)
self.train_path = bags_path.iloc[0:int(len(bags_path)*0.8), :]
self.test_path = bags_path.iloc[int(len(bags_path)*0.8):, :]
# self.train_path = shuffle(train_path).reset_index(drop=True)
# self.test_path = shuffle(test_path).reset_index(drop=True)
home = Path.cwd().parts[1]
self.origin_path = Path(f'/{home}/ylan/RCC_project/rcc_classification/')
# self.target_number = target_number
# self.mean_bag_length = mean_bag_length
# self.var_bag_length = var_bag_length
# self.num_bag = num_bag
self.cache = cache
self.train = train
self.n_classes = 2
self.features = []
self.labels = []
if self.cache:
# pre-load all bag features into memory; a single progress bar is enough
paths = self.train_path if train else self.test_path
for _, t in tqdm(paths.iterrows(), total=len(paths)):
ft, lbl = self.get_bag_feats(t)
self.labels.append(lbl)
self.features.append(ft)
# print(self.get_bag_feats(self.train_path))
# self.r = np.random.RandomState(seed)
# self.num_in_train = 60000
# self.num_in_test = 10000
# if self.train:
# self.train_bags_list, self.train_labels_list = self._create_bags()
# else:
# self.test_bags_list, self.test_labels_list = self._create_bags()
def get_bag_feats(self, csv_file_df):
# if args.dataset == 'TCGA-lung-default':
# feats_csv_path = 'datasets/tcga-dataset/tcga_lung_data_feats/' + csv_file_df.iloc[0].split('/')[1] + '.csv'
# else:
feats_csv_path = self.origin_path / csv_file_df.iloc[0]
df = pd.read_csv(feats_csv_path)
# feats = shuffle(df).reset_index(drop=True)
# feats = feats.to_numpy()
feats = df.to_numpy()
label = np.zeros(self.n_classes)
if self.n_classes==2:
label[1] = csv_file_df.iloc[1]
else:
if int(csv_file_df.iloc[1])<=(len(label)-1):
label[int(csv_file_df.iloc[1])] = 1
return feats, label
def __len__(self):
if self.train:
return len(self.train_path)
else:
return len(self.test_path)
def __getitem__(self, index):
if self.cache:
label = self.labels[index]
feats = self.features[index]
label = Variable(Tensor(label))
feats = Variable(Tensor(feats)).view(-1, 512)
return feats, label
else:
if self.train:
feats, label = self.get_bag_feats(self.train_path.iloc[index])
label = Variable(Tensor(label))
feats = Variable(Tensor(feats)).view(-1, 512)
else:
feats, label = self.get_bag_feats(self.test_path.iloc[index])
label = Variable(Tensor(label))
feats = Variable(Tensor(feats)).view(-1, 512)
return feats, label
if __name__ == '__main__':
import os
cwd = os.getcwd()
home = cwd.split('/')[1]
data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv'
dataset = FeatureBagLoader(data_root, cache=False)
for i in dataset:
# print(i[1])
# print(i)
features, label = i
print(label)
# print(features.shape)
# print(label[0].long())
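# Hedged usage sketch (not part of the commit): bags vary in length, so keep
# batch_size=1 unless a custom collate_fn pads them to a common size.
#
# loader = data_utils.DataLoader(dataset, batch_size=1, shuffle=True)
# feats, label = next(iter(loader))  # feats: [1, n_tiles, 512], label: [1, n_classes]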
import h5py
# import helpers
import numpy as np
from pathlib import Path
import torch
# from torch._C import long
from torch.utils import data
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
# from histoTransforms import RandomHueSaturationValue
import torchvision.transforms as transforms
import torch.nn.functional as F
import csv
from PIL import Image
import cv2
import pandas as pd
import json
class HDF5MILDataloader(data.Dataset):
"""Represents an abstract HDF5 dataset. For single H5 container!
Input params:
file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
mode: 'train' or 'test'
load_data: If True, loads all the data immediately into RAM. Use this if
the dataset is fits into memory. Otherwise, leave this at false and
the data will load lazily.
data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
"""
def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20):
super().__init__()
self.data_info = []
self.data_cache = {}
self.slideLabelDict = {}
self.data_cache_size = data_cache_size
self.mode = mode
self.file_path = file_path
# self.csv_path = csv_path
self.label_path = label_path
self.n_classes = n_classes
self.bag_size = 120
# self.label_file = label_path
recursive = True
# read labels and slide_path from csv
# df = pd.read_csv(self.csv_path)
# labels = df.LABEL
# slides = df.FILENAME
with open(self.label_path, 'r') as f:
self.slideLabelDict = json.load(f)[mode]
self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict}
# if Path(slides[0]).suffix:
# slides = list(map(lambda x: Path(x).stem, slides))
# print(labels)
# print(slides)
# self.slideLabelDict = dict(zip(slides, labels))
# print(self.slideLabelDict)
#check if files in slideLabelDict, only take files that are available.
files_in_path = list(Path(self.file_path).rglob('*.hdf5'))
files_in_path = [x.stem for x in files_in_path]
# print(len(files_in_path))
# print(files_in_path)
# print(list(self.slideLabelDict.keys()))
# for x in list(self.slideLabelDict.keys()):
# if x in files_in_path:
# path = Path(self.file_path) / (x + '.hdf5')
# print(path)
self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path]
print(len(self.files))
# self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys())))
for h5dataset_fp in tqdm(self.files):
# print(h5dataset_fp)
self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
# print(self.data_info)
self.resize_transforms = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(256),
])
self.img_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
# histoTransforms.AutoRandomRotation(),
transforms.Lambda(lambda a: np.array(a)),
])
self.hsv_transforms = transforms.Compose([
RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
transforms.ToTensor()
])
# self._add_data_infos(load_data)
def __getitem__(self, index):
# get data
batch, label, name = self.get_data(index)
out_batch = []
if self.mode == 'train':
# print(img)
# print(img.shape)
for img in batch:
img = self.img_transforms(img)
img = self.hsv_transforms(img)
out_batch.append(img)
else:
for img in batch:
img = transforms.functional.to_tensor(img)
out_batch.append(img)
if len(out_batch) == 0:
# print(name)
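# placeholder behavior: an empty bag is replaced by random noise so the
# pipeline does not crash; real handling would skip or resample the slide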
out_batch = torch.randn(100,3,256,256)
else: out_batch = torch.stack(out_batch)
out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
label = torch.as_tensor(label)
label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
return out_batch, label, name
def __len__(self):
return len(self.data_info)
def _add_data_infos(self, file_path, load_data):
wsi_name = Path(file_path).stem
if wsi_name in self.slideLabelDict:
label = self.slideLabelDict[wsi_name]
wsi_batch = []
# with h5py.File(file_path, 'r') as h5_file:
# numKeys = len(h5_file.keys())
# sample = list(h5_file.keys())[0]
# shape = (numKeys,) + h5_file[sample][:].shape
# for tile in h5_file.keys():
# img = h5_file[tile][:]
# print(img)
# if type == 'images':
# t = 'data'
# else:
# t = 'label'
idx = -1
# if load_data:
# for tile in h5_file.keys():
# img = h5_file[tile][:]
# img = img.astype(np.uint8)
# img = self.resize_transforms(img)
# wsi_batch.append(img)
# idx = self._add_to_cache(wsi_batch, file_path)
# wsi_batch.append(img)
# self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx})
self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
def _load_data(self, file_path):
"""Load data to the cache given the file
path and update the cache index in the
data_info structure.
"""
with h5py.File(file_path, 'r') as h5_file:
wsi_batch = []
for tile in h5_file.keys():
img = h5_file[tile][:]
img = img.astype(np.uint8)
img = self.resize_transforms(img)
wsi_batch.append(img)
idx = self._add_to_cache(wsi_batch, file_path)
file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
self.data_info[file_idx + idx]['cache_idx'] = idx
# for type in ['images', 'labels']:
# for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()):
# img = h5_file[data_path][:]
# idx = self._add_to_cache(img, data_path)
# file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path)
# self.data_info[file_idx + idx]['cache_idx'] = idx
# for gname, group in h5_file.items():
# for dname, ds in group.items():
# # add data to the data cache and retrieve
# # the cache index
# idx = self._add_to_cache(ds.value, file_path)
# # find the beginning index of the hdf5 file we are looking for
# file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)
# # the data info should have the same index since we loaded it in the same way
# self.data_info[file_idx + idx]['cache_idx'] = idx
# remove an element from data cache if size was exceeded
if len(self.data_cache) > self.data_cache_size:
# evict the first cached file that is not the current one (dict insertion order)
removal_keys = list(self.data_cache)
removal_keys.remove(file_path)
self.data_cache.pop(removal_keys[0])
# remove invalid cache_idx
# self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
def _add_to_cache(self, data, data_path):
"""Adds data to the cache and returns its index. There is one cache
list for every file_path, containing all datasets in that file.
"""
if data_path not in self.data_cache:
self.data_cache[data_path] = [data]
else:
self.data_cache[data_path].append(data)
return len(self.data_cache[data_path]) - 1
# def get_data_infos(self, type):
# """Get data infos belonging to a certain type of data.
# """
# data_info_type = [di for di in self.data_info if di['type'] == type]
# return data_info_type
def get_name(self, i):
# name = self.get_data_infos(type)[i]['name']
name = self.data_info[i]['name']
return name
def get_data(self, i):
"""Call this function anytime you want to access a chunk of data from the
dataset. This will make sure that the data is loaded in case it is
not part of the data cache.
i = index
"""
# fp = self.get_data_infos(type)[i]['data_path']
fp = self.data_info[i]['data_path']
if fp not in self.data_cache:
self._load_data(fp)
# get new cache_idx assigned by _load_data_info
# cache_idx = self.get_data_infos(type)[i]['cache_idx']
cache_idx = self.data_info[i]['cache_idx']
label = self.data_info[i]['label']
name = self.data_info[i]['name']
# print(self.data_cache[fp][cache_idx])
return self.data_cache[fp][cache_idx], label, name
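# Note on caching: the first get_data(i) for a file triggers _load_data, which
# decodes and resizes every tile in that HDF5 container; subsequent accesses
# are served from self.data_cache until the file is evicted.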
class RandomHueSaturationValue(object):
def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
self.hue_shift_limit = hue_shift_limit
self.sat_shift_limit = sat_shift_limit
self.val_shift_limit = val_shift_limit
self.p = p
def __call__(self, sample):
img = sample #,lbl
if np.random.random() < self.p:
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32
h, s, v = cv2.split(img)
hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
hue_shift = np.uint8(hue_shift)
h += hue_shift
sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
s = cv2.add(s, sat_shift)
val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
v = cv2.add(v, val_shift)
img = cv2.merge((h, s, v))
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
return img #, lbl
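# Minimal sketch of the transform in isolation (assumes a uint8 RGB array; the
# hue channel shifts with uint8 wrap-around, saturation/value use cv2's
# saturating add):
#
# aug = RandomHueSaturationValue(p=1.0)
# img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
# out = aug(img)  # same shape and dtype, randomly shifted in HSV space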
if __name__ == '__main__':
from pathlib import Path
import os
home = Path.cwd().parts[1]
train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
# label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json'
output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/'
# os.makedirs(output_path, exist_ok=True)
dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6)
data = DataLoader(dataset, batch_size=1)
# print(len(dataset))
x = 0
c = 0
for item in data:
if c >=10:
break
bag, label, name = item
print(bag)
# # print(bag.shape)
# if bag.shape[1] == 1:
# print(name)
# print(bag.shape)
# print(bag.shape)
# print(name)
# out_dir = Path(output_path) / name
# os.makedirs(out_dir, exist_ok=True)
# # print(item[2])
# # print(len(item))
# # print(item[1])
# # print(data.shape)
# # data = data.squeeze()
# bag = item[0]
# bag = bag.squeeze()
# for i in range(bag.shape[0]):
# img = bag[i, :, :, :]
# img = img.squeeze()
# img = img*255
# img = img.numpy().astype(np.uint8).transpose(1,2,0)
# img = Image.fromarray(img)
# img = img.convert('RGB')
# img.save(f'{out_dir}/{i}.png')
c += 1
# else: break
# print(data.shape)
# print(label)
import inspect  # inspect Python class signatures and module/function source
import importlib  # in order to dynamically import the library
from typing import Optional
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from .camel_dataloader import FeatureBagLoader
from .custom_dataloader import HDF5MILDataloader
from pathlib import Path
class DataInterface(pl.LightningDataModule):
@@ -24,6 +28,8 @@ class DataInterface(pl.LightningDataModule):
self.dataset_name = dataset_name
self.kwargs = kwargs
self.load_data_module()
home = Path.cwd().parts[1]
self.data_root = f'/{home}/ylan/RCC_project/rcc_classification/datasets/Camelyon16/Camelyon16.csv'
@@ -46,14 +52,23 @@ class DataInterface(pl.LightningDataModule):
"""
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
dataset = FeatureBagLoader(data_root = self.data_root,
train=True)
a = int(len(dataset)* 0.8)
b = int(len(dataset) - a)
print(a)
print(b)
self.train_dataset, self.val_dataset = random_split(dataset, [a, b])
# self.train_dataset = self.instancialize(state='train')
# self.val_dataset = self.instancialize(state='val')
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
# self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
self.test_dataset = FeatureBagLoader(data_root = self.data_root,
train=False)
# self.test_dataset = self.instancialize(state='test')
def train_dataloader(self):
@@ -88,3 +103,61 @@ class DataInterface(pl.LightningDataModule):
args1[arg] = self.kwargs[arg]
args1.update(other_args)
return self.data_module(**args1)
class MILDataModule(pl.LightningDataModule):
def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, **kwargs):
super().__init__()
self.data_root = data_root
self.label_path = label_path
self.batch_size = batch_size
self.num_workers = num_workers
self.image_size = 384
self.n_classes = n_classes
self.target_number = 9
self.mean_bag_length = 10
self.var_bag_length = 2
self.num_bags_train = 200
self.num_bags_test = 50
self.seed = 1
self.cache = cache  # honor the constructor argument instead of hard-coding True
def setup(self, stage: Optional[str] = None) -> None:
# if self.n_classes == 2:
# if stage in (None, 'fit'):
# dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes)
# a = int(len(dataset)* 0.8)
# b = int(len(dataset) - a)
# self.train_data, self.valid_data = random_split(dataset, [a, b])
# if stage in (None, 'test'):
# self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes)
# else:
home = Path.cwd().parts[1]
# self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json'
# train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv'
# test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv'
if stage in (None, 'fit'):
dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
# print(len(dataset))
a = int(len(dataset)* 0.8)
b = int(len(dataset) - a)
self.train_data, self.valid_data = random_split(dataset, [a, b])
if stage in (None, 'test'):
self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes)
return super().setup(stage=stage)
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True,
def val_dataloader(self) -> DataLoader:
return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
def test_dataloader(self) -> DataLoader:
return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
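# Hedged usage sketch (paths are placeholders, not taken from this repo's configs):
#
# dm = MILDataModule(data_root='/path/to/hdf5', label_path='/path/to/split.json',
#                    batch_size=1, n_classes=2)
# dm.setup('fit')
# bag, label, name = next(iter(dm.train_dataloader()))  # bag: [1, n_tiles, 3, 256, 256]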
@@ -47,7 +47,8 @@ class TransMIL(nn.Module):
def __init__(self, n_classes):
super(TransMIL, self).__init__()
self.pos_layer = PPEG(dim=512)
self._fc1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
# self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
self.n_classes = n_classes
self.layer1 = TransLayer(dim=512)
@@ -56,11 +57,10 @@ class TransMIL(nn.Module):
self._fc2 = nn.Linear(512, self.n_classes)
def forward(self, **kwargs):
h = kwargs['data'].float() #[B, n, 512]
# h = self._fc1(h) #[B, n, 512]
h = self._fc1(h) #[B, n, 512]
#---->pad
H = h.shape[1]
@@ -86,15 +86,19 @@ class TransMIL(nn.Module):
h = self.norm(h)[:,0]
#---->predict
logits = self._fc2(torch.sigmoid(h)) #[B, n_classes]
Y_hat = torch.argmax(logits, dim=1)
Y_prob = F.softmax(logits, dim = 1)
results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
return results_dict
if __name__ == "__main__":
data = torch.randn((1, 6000, 512)).cuda()
model = TransMIL(n_classes=2).cuda()
print(model.eval())
results_dict = model(data = data)
print(results_dict)
logits = results_dict['logits']
Y_prob = results_dict['Y_prob']
Y_hat = results_dict['Y_hat']
# print(F.sigmoid(logits))
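# Hedged shape check for the changed head (logits = self._fc2(torch.sigmoid(h)),
# fed by the new 512-d input projection); mirrors the __main__ block without CUDA:
#
# model = TransMIL(n_classes=2).eval()
# out = model(data = torch.randn(1, 100, 512))
# assert out['logits'].shape == (1, 2)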
@@ -4,6 +4,9 @@ import inspect
import importlib
import random
import pandas as pd
import seaborn as sns
from pathlib import Path
from matplotlib import pyplot as plt
#---->
from MyOptimizer import create_optimizer
@@ -18,7 +21,9 @@ import torchmetrics
#---->
import pytorch_lightning as pl
from .vision_transformer import vit_small
from torchvision import models
from torchvision.models import resnet
class ModelInterface(pl.LightningModule):
@@ -37,11 +42,11 @@ class ModelInterface(pl.LightningModule):
#---->Metrics
if self.n_classes > 2:
self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
average='micro'),
torchmetrics.CohenKappa(num_classes = self.n_classes),
torchmetrics.F1Score(num_classes = self.n_classes,
average = 'macro'),
torchmetrics.Recall(average = 'macro',
num_classes = self.n_classes),
@@ -49,17 +54,19 @@ class ModelInterface(pl.LightningModule):
num_classes = self.n_classes),
torchmetrics.Specificity(average = 'macro',
num_classes = self.n_classes)])
else:
self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
average = 'micro'),
torchmetrics.CohenKappa(num_classes = 2),
torchmetrics.F1Score(num_classes = 2,
average = 'macro'),
torchmetrics.Recall(average = 'macro',
num_classes = 2),
torchmetrics.Precision(average = 'macro',
num_classes = 2)])
self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
self.valid_metrics = metrics.clone(prefix = 'val_')
self.test_metrics = metrics.clone(prefix = 'test_')
@@ -67,18 +74,103 @@ class ModelInterface(pl.LightningModule):
self.shuffle = kargs['data'].data_shuffle
self.count = 0
self.out_features = 512
if kargs['backbone'] == 'dino':
#---> dino feature extractor
arch = 'vit_small'
patch_size = 16
n_last_blocks = 4
# num_labels = 1000
avgpool_patchtokens = False
home = Path.cwd().parts[1]
weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth'
model = vit_small(patch_size, num_classes=0)
# model.eval()
# set_parameter_requires_grad(model, feature_extracting)
for param in model.parameters():
param.requires_grad = False
# print(model.embed_dim)
# embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens))
# model.eval()
# print(embed_dim)
linear = nn.Linear(model.embed_dim, self.out_features)
linear.weight.data.normal_(mean=0.0, std=0.01)
linear.bias.data.zero_()
self.model_ft = nn.Sequential(
model,
linear,
)
elif kargs['backbone'] == 'resnet18':
resnet18 = models.resnet18(pretrained=True)
modules = list(resnet18.children())[:-1]
# model_ft.fc = nn.Linear(512, out_features)
res18 = nn.Sequential(
*modules,
)
for param in res18.parameters():
param.requires_grad = False
self.model_ft = nn.Sequential(
res18,
nn.AdaptiveAvgPool2d(1),
View((-1, 512)),
nn.Linear(512, self.out_features),
nn.ReLU(),
)
elif kargs['backbone'] == 'resnet50':
resnet50 = models.resnet50(pretrained=True)
# model_ft.fc = nn.Linear(1024, out_features)
modules = list(resnet50.children())[:-3]
res50 = nn.Sequential(
*modules,
)
for param in res50.parameters():
param.requires_grad = False
self.model_ft = nn.Sequential(
res50,
nn.AdaptiveAvgPool2d(1),
View((-1, 1024)),
nn.Linear(1024, self.out_features),
nn.ReLU()
)
elif kargs['backbone'] == 'simple': #mil-ab attention
feature_extracting = False
self.model_ft = nn.Sequential(
nn.Conv2d(3, 20, kernel_size=5),
nn.ReLU(),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(20, 50, kernel_size=5),
nn.ReLU(),
nn.MaxPool2d(2, stride=2),
View((-1, 1024)),
nn.Linear(1024, self.out_features),
nn.ReLU(),
)
#---->remove v_num
# def get_progress_bar_dict(self):
# # don't show the version number
# items = super().get_progress_bar_dict()
# items.pop("v_num", None)
# return items
def training_step(self, batch, batch_idx):
#---->inference
data, label, _ = batch
label = label.float()
data = data.squeeze(0)
# print(data.shape)
features = self.model_ft(data)
features = features.unsqueeze(0)
# print(features.shape)
# features = features.squeeze()
results_dict = self.model(data=features)
# results_dict = self.model(data=data, label=label)
logits = results_dict['logits']
Y_prob = results_dict['Y_prob']
Y_hat = results_dict['Y_hat']
@@ -87,8 +179,13 @@ class ModelInterface(pl.LightningModule):
loss = self.loss(logits, label)
#---->acc log
# print(label)
Y_hat = int(Y_hat)
# if self.n_classes == 2:
# Y = int(label[0][1])
# else:
Y = torch.argmax(label)
# Y = int(label[0])
self.data[Y]["count"] += 1
self.data[Y]["correct"] += (Y_hat == Y)
@@ -106,19 +203,28 @@ class ModelInterface(pl.LightningModule):
self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
def validation_step(self, batch, batch_idx):
data, label, _ = batch
label = label.float()
data = data.squeeze(0)
features = self.model_ft(data)
features = features.unsqueeze(0)
results_dict = self.model(data=features)
logits = results_dict['logits']
Y_prob = results_dict['Y_prob']
Y_hat = results_dict['Y_hat']
#---->acc log
# Y = int(label[0][1])
Y = torch.argmax(label)
self.data[Y]["count"] += 1
self.data[Y]["correct"] += (Y_hat.item() == Y)
return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
def validation_epoch_end(self, val_step_outputs):
@@ -126,13 +232,26 @@ class ModelInterface(pl.LightningModule):
probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
#---->
# logits = logits.long()
# target = target.squeeze().long()
# logits = logits.squeeze(0)
self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
# print(max_probs.squeeze(0).shape)
# print(target.shape)
self.log_dict(self.valid_metrics(max_probs.squeeze() , target),
on_epoch = True, logger = True)
#----> log confusion matrix
confmat = self.confusion_matrix(max_probs.squeeze(), target)
df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
plt.figure()
fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
plt.close(fig_)
self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
#---->acc log
for c in range(self.n_classes):
count = self.data[c]["count"]
@@ -156,18 +275,24 @@ class ModelInterface(pl.LightningModule):
return [optimizer]
def test_step(self, batch, batch_idx):
data, label, _ = batch
label = label.float()
data = data.squeeze(0)
features = self.model_ft(data)
features = features.unsqueeze(0)
results_dict = self.model(data=features, label=label)
logits = results_dict['logits']
Y_prob = results_dict['Y_prob']
Y_hat = results_dict['Y_hat']
#---->acc log
Y = torch.argmax(label)
self.data[Y]["count"] += 1
self.data[Y]["correct"] += (Y_hat.item() == Y)
return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
def test_epoch_end(self, output_results):
probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0)
@@ -176,12 +301,20 @@ class ModelInterface(pl.LightningModule):
#---->
auc = self.AUROC(probs, target.squeeze())
metrics = self.test_metrics(max_probs.squeeze() , target)
# metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
metrics['test_auc'] = auc
# self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
# print(max_probs.squeeze(0).shape)
# print(target.shape)
# self.log_dict(metrics, logger = True)
for keys, values in metrics.items():
print(f'{keys} = {values}')
metrics[keys] = values.cpu().numpy()
print()
#---->acc log
for c in range(self.n_classes):
count = self.data[c]["count"]
@@ -192,6 +325,16 @@ class ModelInterface(pl.LightningModule):
acc = float(correct) / count
print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
confmat = self.confusion_matrix(max_probs.squeeze(), target)
df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
plt.figure()
fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
# plt.close(fig_)
# self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
plt.savefig(f'{self.log_path}/cm_test')
plt.close(fig_)
#---->
result = pd.DataFrame([metrics])
result.to_csv(self.log_path / 'result.csv')
@@ -227,3 +370,17 @@ class ModelInterface(pl.LightningModule):
args1[arg] = getattr(self.hparams.model, arg)
args1.update(other_args)
return Model(**args1)
class View(nn.Module):
def __init__(self, shape):
super().__init__()
self.shape = shape
def forward(self, input):
'''
Reshapes the input according to the shape saved in the view data structure.
'''
# batch_size = input.size(0)
# shape = (batch_size, *self.shape)
out = input.view(*self.shape)
return out
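# View is a thin nn.Module wrapper around Tensor.view so a reshape can sit
# inside nn.Sequential, as in the resnet branches of ModelInterface above.
# Minimal sketch:
#
# flatten = View((-1, 512))
# x = torch.randn(8, 512, 1, 1)
# print(flatten(x).shape)  # torch.Size([8, 512])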
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
import warnings  # needed by the out-of-range warning in _no_grad_trunc_normal_
from functools import partial
import torch
import torch.nn as nn
# from utils import trunc_normal_
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, return_attention=False):
y, attn = self.attn(self.norm1(x))
if return_attention:
return attn
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class VisionTransformer(nn.Module):
""" Vision Transformer """
def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
super().__init__()
self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
class_pos_embed = self.pos_embed[:, 0]
patch_pos_embed = self.pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size
h0 = h // self.patch_embed.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def prepare_tokens(self, x):
B, nc, w, h = x.shape
x = self.patch_embed(x) # patch linear embedding
# add the [CLS] token to the embed patch tokens
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward(self, x):
x = self.prepare_tokens(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 0]
def get_last_selfattention(self, x):
x = self.prepare_tokens(x)
for i, blk in enumerate(self.blocks):
if i < len(self.blocks) - 1:
x = blk(x)
else:
# return attention of the last block
return blk(x, return_attention=True)
def get_intermediate_layers(self, x, n=1):
x = self.prepare_tokens(x)
# we return the output tokens from the `n` last blocks
output = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if len(self.blocks) - i <= n:
output.append(self.norm(x))
return output
def vit_tiny(patch_size=16, **kwargs):
model = VisionTransformer(
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def vit_small(patch_size=16, **kwargs):
model = VisionTransformer(
patch_size=patch_size, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, #num_heads=6
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def vit_base(patch_size=16, **kwargs):
model = VisionTransformer(
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
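# Hedged feature-extraction sketch for the factories above: with num_classes=0
# the head is nn.Identity, so forward() returns the CLS embedding (384-d for
# vit_small, matching the nn.Linear(model.embed_dim, 512) projection used in
# ModelInterface). The 224x224 input size is an assumption:
#
# model = vit_small(patch_size=16, num_classes=0)
# feats = model(torch.randn(2, 3, 224, 224))  # [2, 384]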
class DINOHead(nn.Module):
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = nn.Linear(in_dim, bottleneck_dim)
else:
layers = [nn.Linear(in_dim, hidden_dim)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x
@@ -3,8 +3,9 @@ from pathlib import Path
import numpy as np
import glob
from datasets.data_interface import DataInterface, MILDataModule
from models.model_interface import ModelInterface
import models.vision_transformer as vits
from utils.utils import *
# pytorch_lightning
@@ -15,8 +16,8 @@ from pytorch_lightning import Trainer
def make_parse():
parser = argparse.ArgumentParser()
parser.add_argument('--stage', default='train', type=str)
parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
# parser.add_argument('--gpus', default = [2])
parser.add_argument('--fold', default = 0)
args = parser.parse_args()
return args
@@ -34,20 +35,31 @@ def main(cfg):
cfg.callbacks = load_callbacks(cfg)
#---->Define Data
# DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
# 'train_num_workers': cfg.Data.train_dataloader.num_workers,
# 'test_batch_size': cfg.Data.test_dataloader.batch_size,
# 'test_num_workers': cfg.Data.test_dataloader.num_workers,
# 'dataset_name': cfg.Data.dataset_name,
# 'dataset_cfg': cfg.Data,}
# dm = DataInterface(**DataInterface_dict)
home = Path.cwd().parts[1]
DataInterface_dict = {
'data_root': cfg.Data.data_dir,
'label_path': cfg.Data.label_file,
'batch_size': cfg.Data.train_dataloader.batch_size,
'num_workers': cfg.Data.train_dataloader.num_workers,
'n_classes': cfg.Model.n_classes,
}
dm = MILDataModule(**DataInterface_dict)
#---->Define Model
ModelInterface_dict = {'model': cfg.Model,
'loss': cfg.Loss,
'optimizer': cfg.Optimizer,
'data': cfg.Data,
'log': cfg.log_path,
'backbone': cfg.Model.backbone,
}
model = ModelInterface(**ModelInterface_dict)
@@ -57,12 +69,18 @@ def main(cfg):
logger=cfg.load_loggers,
callbacks=cfg.callbacks,
max_epochs= cfg.General.epochs,
min_epochs = 200,
gpus=cfg.General.gpus,
# gpus = [4],
# strategy='ddp',
amp_backend='native',
# amp_level=cfg.General.amp_level,
precision=cfg.General.precision,
accumulate_grad_batches=cfg.General.grad_acc,
# fast_dev_run = True,
# deterministic=True,
check_val_every_n_epoch=10,
)
#---->train or test
@@ -83,7 +101,7 @@ if __name__ == '__main__':
#---->update
cfg.config = args.config
# cfg.General.gpus = args.gpus
cfg.General.server = args.stage
cfg.Data.fold = args.fold
...
@@ -14,7 +14,7 @@ def load_loggers(cfg):
log_path = cfg.General.log_path
Path(log_path).mkdir(exist_ok=True, parents=True)
log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}'
version_name = Path(cfg.config).name[:-5]
cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
print(f'---->Log dir: {cfg.log_path}')
@@ -31,8 +31,10 @@ def load_loggers(cfg):
#---->load Callback
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
def load_callbacks(cfg):
Mycallbacks = []
@@ -47,7 +49,21 @@ def load_callbacks(cfg):
verbose=True,
mode='min'
)
Mycallbacks.append(early_stop_callback)
progress_bar = RichProgressBar(
theme=RichProgressBarTheme(
description='green_yellow',
progress_bar='green1',
progress_bar_finished='green1',
batch_progress='green_yellow',
time='grey82',
processing_speed='grey82',
metrics='grey82'
)
)
Mycallbacks.append(progress_bar)
if cfg.General.server == 'train':
Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
@@ -64,7 +80,7 @@ def load_callbacks(cfg):
import torch
import torch.nn.functional as F
def cross_entropy_torch(x, y):
x_softmax = [F.softmax(x[i], dim=0) for i in range(len(x))]
x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
loss = - torch.sum(x_log) / y.shape[0]
return loss
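# Sanity-check sketch (assumes integer class targets): the hand-rolled loss
# above should agree with the built-in mean cross entropy up to float error.
#
# x = torch.randn(4, 2)
# y = torch.randint(0, 2, (4,))
# print(cross_entropy_torch(x, y), F.cross_entropy(x, y))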