Commit 2f11c65a authored by Brian Christopher Wasels

UNet 18 Lets go !!!!11

parent 260058fc
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -12,6 +12,9 @@ V13: 4 layers, double conv, normalized data, phase, 64
 V14: 4 layers, single conv, normalized data, phase + angles, 64
 V15: 3 layers, double depthwise conv per layer, normalized data, kernel 7, phase only, dropout 0.3, 32
 V16: 3 layers, double depthwise conv per layer, normalized data, kernel 7, angles only, dropout 0.5, 32
+V17: 3 layers, double depthwise conv per layer, normalized data, kernel 7, angles only, dropout 0.5, 32, but last layer 70 32 1 1
+V18: 3 layers, double depthwise conv per layer, normalized data, kernel 7, angles only, dropout 0.5, 32, like V16 but first layer 6 32 32
 V9 with kernel 7 and only the phases:
 mean error over whole set: 16.91116704929035
 max error average: 292.8658473955995 and maximum 814.873957640188
......
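The new V17/V18 entries are mostly channel bookkeeping in the U-Net skip connections. A minimal sketch of that arithmetic, using only the numbers visible in this commit (the variable names are illustrative, not from the repository), assuming the decoder concatenates each upsampled map with the matching encoder skip, which the 192/160 conv inputs already imply:

``` python
# Why V18's first encoder stage is (6, 32, 32) and its last decoder conv takes 70 channels.
in_channels = 6                                           # angles-only input volume
enc_chs    = ((6, 32, 32), (32, 64, 64), (64, 128, 128))  # V18 encoder stages
dec_chs_up = (128, 128, 64)                               # decoder upsampling output widths

# Each decoder conv sees the upsampled features concatenated with the matching
# encoder skip; the last skip is the input-level feature map, so its width is
# the raw channel count:
assert dec_chs_up[0] + enc_chs[1][-1] == 192  # first decoder conv input
assert dec_chs_up[1] + enc_chs[0][-1] == 160  # second decoder conv input
assert dec_chs_up[2] + in_channels   == 70    # third decoder conv input
# (with the old 2-channel phase-only input this was 64 + 2 = 66, as in the diff below)
```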
-#like V6_2 but only the different phases as input
 """UNet_V6.ipynb
 Automatically generated by Colaboratory.
@@ -25,7 +25,8 @@ class depthwise_separable_conv(nn.Module):
         self.pointwise_1 = nn.Conv3d(in_c, out_1_c, kernel_size=1, bias=True)
         self.batch_norm_1 = nn.BatchNorm3d(out_1_c)
         self.relu = nn.ReLU()
-        self.droptout = nn.Dropout3d(p=0.25)
+        self.droptout = nn.Dropout3d(p=0.5)
         self.depthwise_2 = nn.Conv3d(out_1_c, out_1_c, kernel_size=kernel_size, padding=padding[1], groups=out_1_c, bias=True)
         self.pointwise_2 = nn.Conv3d(out_1_c, out_2_c, kernel_size=1, bias=True)
         self.batch_norm_2 = nn.BatchNorm3d(out_2_c)
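For context, the block this hunk edits stacks two depthwise-pointwise pairs. A self-contained sketch, reconstructed from the lines visible here plus the `forward` shown in the traceback at the bottom of this page; the `depthwise_1` definition is not visible in the hunk and is assumed symmetric to `depthwise_2`:

``` python
import torch.nn as nn

class depthwise_separable_conv(nn.Module):
    """Two stacked depthwise-separable 3D convolutions with dropout, ReLU and BatchNorm."""
    def __init__(self, in_c, out_1_c, out_2_c, kernel_size=7, padding=(3, 3)):
        super().__init__()
        # assumed: mirrors depthwise_2, which is visible in the hunk
        self.depthwise_1 = nn.Conv3d(in_c, in_c, kernel_size=kernel_size,
                                     padding=padding[0], groups=in_c, bias=True)
        self.pointwise_1 = nn.Conv3d(in_c, out_1_c, kernel_size=1, bias=True)
        self.batch_norm_1 = nn.BatchNorm3d(out_1_c)
        self.relu = nn.ReLU()
        self.droptout = nn.Dropout3d(p=0.5)   # raised from 0.25 in this commit
        self.depthwise_2 = nn.Conv3d(out_1_c, out_1_c, kernel_size=kernel_size,
                                     padding=padding[1], groups=out_1_c, bias=True)
        self.pointwise_2 = nn.Conv3d(out_1_c, out_2_c, kernel_size=1, bias=True)
        self.batch_norm_2 = nn.BatchNorm3d(out_2_c)

    def forward(self, x):
        # operator order taken from the traceback: conv -> dropout -> ReLU -> BatchNorm
        x = self.batch_norm_1(self.relu(self.droptout(self.pointwise_1(self.depthwise_1(x)))))
        return self.batch_norm_2(self.relu(self.droptout(self.pointwise_2(self.depthwise_2(x)))))
```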
@@ -136,7 +137,7 @@ def accuracy(outputs, labels, normalization, threshold = 0.05):
     return percentage
 class UNet(UNetBase):
-    def __init__(self, kernel_size = 7, enc_chs=((2,16,32), (32,32,64), (64,64,128)), dec_chs_up=(128, 128, 64), dec_chs_conv=((192,128, 128),(160,64,64),(66,32,32)), normalization=np.array([0,1])):
+    def __init__(self, kernel_size = 7, enc_chs=((6,32,32), (32,64,64), (64,128,128)), dec_chs_up=(128, 128, 64), dec_chs_conv=((192,128, 128),(160,64,64),(70,32,32)), normalization=np.array([0,1])):
         super().__init__()
         self.encoder = Encoder(kernel_size = kernel_size, chs = enc_chs)
         self.decoder = Decoder(kernel_size = kernel_size, chs_upsampling = dec_chs_up, chs_conv = dec_chs_conv)
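A quick smoke test of the new defaults: with 7³ kernels and matching padding the spatial size is preserved, so a 6-channel 32³ input should come back as a single-channel 32³ stress field. This is a sketch only; the module name `UNet_V18` and the 1-channel head are assumptions based on the commit's naming scheme and the V17 changelog entry:

``` python
import torch
import UNet_V18 as UNet  # assumed file name for this commit's script

model = UNet.UNet().double()            # defaults now carry the V18 channel widths
x = torch.randn(1, 6, 32, 32, 32).double()
model.eval()
with torch.no_grad():                   # inference only; avoids autograd buffers
    y = model(x)
print(y.shape)                          # expected: torch.Size([1, 1, 32, 32, 32])
```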
@@ -174,8 +175,8 @@ def fit(epochs, lr, model, train_loader, val_loader, path, opt_func=torch.optim.
         result['train_loss'] = torch.stack(train_losses).mean().item()
         model.epoch_end(epoch, result)
         history.append(result)
-    torch.save(model.state_dict(), f'{path}/Unet_dict_V15.pth')
-    torch.save(history, f'{path}/history_V15.pt')
+    torch.save(model.state_dict(), f'{path}/Unet_dict_V18.pth')
+    torch.save(history, f'{path}/history_V18.pt')
     return history
 def get_default_device():
@@ -225,9 +226,9 @@ def Create_Dataloader(path, batch_size = 100, percent_val = 0.2):
 if __name__ == '__main__':
     #os.chdir('F:/RWTH/HiWi_IEHK/DAMASK3/UNet/Trainingsdata')
     path_to_rep = '/home/yk138599/Hiwi/damask3'
-    use_seeds = False
-    seed = 373686838
-    num_epochs = 1000
+    use_seeds = True
+    seed = 2199910834
+    num_epochs = 200
     b_size = 32
     opt_func = torch.optim.Adam
     lr = 0.00003
@@ -243,8 +244,8 @@ if __name__ == '__main__':
     random.seed(seed)
     np.random.seed(seed)
     device = get_default_device()
-    normalization = np.load(f'{path_to_rep}/UNet/Trainingsdata/Norm_min_max_32_phase_only.npy', allow_pickle = True)
-    train_dl, valid_dl = Create_Dataloader(f'{path_to_rep}/UNet/Trainingsdata/TD_norm_32_phase_only.pt', batch_size= b_size )
+    normalization = np.load(f'{path_to_rep}/UNet/Trainingsdata/Norm_min_max_32_angles.npy', allow_pickle = True)
+    train_dl, valid_dl = Create_Dataloader(f'{path_to_rep}/UNet/Trainingsdata/TD_norm_32_angles.pt', batch_size= b_size )
     train_dl = DeviceDataLoader(train_dl, device)
     valid_dl = DeviceDataLoader(valid_dl, device)
......
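`DeviceDataLoader` and `to_device` are used above but defined outside the visible hunks. The conventional implementation of this pattern looks like the sketch below; the repository's version may differ in details:

``` python
def to_device(data, device):
    """Move a tensor (or a list/tuple of tensors) to the given device."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a DataLoader so that every batch is moved to the device on iteration."""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)
```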
No preview for this file type
No preview for this file type
%% Cell type:code id: tags:

``` python
import torch
import numpy as np
import matplotlib.pyplot as plt
import UNet_V15 as UNet15
import UNet_V9_3 as UNet9
import pyvista as pv
from matplotlib.colors import ListedColormap
import copy
```
%% Cell type:code id: tags:

``` python
def predict_stress(image_id, normalization, model, dataset, grain_data, UNet, threshold = 0.1):
    input, output = dataset[image_id]
    grain, _ = grain_data[image_id]
    grain = copy.deepcopy(grain)
    grain = torch.unsqueeze(grain, 0)
    grain = grain.detach().numpy()
    input = copy.deepcopy(input)
    output = copy.deepcopy(output)
    input = torch.unsqueeze(input, 0)
    output = torch.unsqueeze(output, 0)
    xb = UNet.to_device(input, device_9)  # note: relies on the module-level device_9, even for other models
    model.eval()
    prediction = model(xb)
    input = input.detach().numpy()
    prediction = prediction.detach().numpy()
    output = output.detach().numpy()
    prediction = rescale(prediction, normalization)
    output = rescale(output, normalization)
    error = (abs(output - prediction) / output)
    print(f'Maximum error is : {error.max()*100.:.4} %')
    print(f'average error is : {error.mean()*100.:.4} %')
    right_predic = (error < threshold).sum()
    print(f'{(right_predic/error.size)*100.:.4}% of voxels have a deviation less than {threshold*100.}%')
    grains = grain_matrix_colormap(grain)
    plot_difference(error, grains, output, threshold)

def rescale(output, normalization):
    output_rescale = output.reshape(output.shape[2], output.shape[3], output.shape[4])
    if normalization is not None:
        min_label, max_label = normalization
        output_rescale *= max_label
        output_rescale += min_label
    return output_rescale

def get_colormap(mesh, threshold):
    # RWTH palette; only red and blue are currently used, the rest are kept for experiments
    black = np.array([11/256, 11/256, 11/256, 1])
    yellow = np.array([255/256, 237/256, 0/256, 1])
    orange = np.array([245/256, 167/256, 0/256, 1])
    red = np.array([203/256, 6/256, 29/256, 1])
    bordeaux = np.array([160/256, 15/256, 53/256, 1])
    blue = np.array([0/256, 84/256, 159/256, 1])
    mapping = np.linspace(mesh['error'].min(), mesh['error'].max(), 256)
    newcolors = np.empty((256, 4))
    # binary colormap: red above the fixed 15% cutoff, blue below
    # (note: the threshold argument is currently unused here)
    newcolors[mapping >= 0.15] = red
    #newcolors[mapping < 0.5] = red
    #newcolors[mapping < 0.3] = orange
    newcolors[mapping < 0.15] = blue
    #newcolors[mapping < 0.05] = black
    return ListedColormap(newcolors)

def plot_losses(history):
    # skip the first 50 epochs so the early transient does not dominate the plot
    train_losses = [x['train_loss'] for x in history[50:]]
    val_acc = [x['val_acc'] for x in history[50:]]
    val_loss = [x['val_loss'] for x in history[50:]]
    fig, ax1 = plt.subplots()
    color = 'tab:red'
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('train/val loss', color=color)
    lns1 = ax1.plot(train_losses, color=color, label = 'training loss')
    lns2 = ax1.plot(val_loss, color='tab:green', label = 'validation loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    color = 'tab:blue'
    ax2.set_ylabel('validation accuracy', color=color)  # we already handled the x-label with ax1
    lns3 = ax2.plot(val_acc, color=color, label = 'validation accuracy')
    ax2.tick_params(axis='y', labelcolor=color)
    lns = lns1 + lns2 + lns3
    labs = [l.get_label() for l in lns]
    ax1.legend(lns, labs, loc=0)
    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    plt.title('Loss vs. No. of epochs')

def grain_matrix_colormap(input):
    matrix_grains = input[0, 0, :, :, :]
    matrix_ferrit = input[0, 5, :, :, :]  # elements are 1 where the phase is ferrite, else 0
    #unique_angles = np.unique(matrix_grains)
    matrix_ferrit_grains = np.multiply(matrix_grains, matrix_ferrit)  # only the ferrite grains stay nonzero
    index_ferrit_angles = np.unique(matrix_ferrit_grains[matrix_ferrit_grains != 0])
    index_martensite_angles = np.setdiff1d(np.unique(matrix_grains), index_ferrit_angles)
    # relabel ferrite grains with consecutive ids starting at 0
    for index, angle in enumerate(index_ferrit_angles):
        matrix_grains[matrix_grains == angle] = index
    # relabel martensite grains with ids offset by +100 so the colormap separates the two phases
    for index, angle in enumerate(index_martensite_angles):
        matrix_grains[matrix_grains == angle] = index + len(index_ferrit_angles) + 100
    return matrix_grains
```
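%% Cell type:markdown id: tags:

`rescale` inverts a min-max style normalization by computing `x * max_label + min_label`, which assumes the labels were normalized as `(x - min_label) / max_label`; whether `max_label` holds the raw maximum or the range is decided by the preprocessing script, which is not part of this commit. A small round-trip check under that assumption, with made-up constants:

%% Cell type:code id: tags:

``` python
# hypothetical sanity check for rescale (constants are illustrative, not from the data)
min_label, max_label = 10.0, 250.0
labels = np.random.rand(1, 1, 4, 4, 4) * max_label + min_label
labels_norm = (labels - min_label) / max_label
restored = rescale(labels_norm.copy(), (min_label, max_label))
print(np.allclose(restored, labels.reshape(4, 4, 4)))  # expected: True
```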
%% Cell type:code id: tags:

``` python
def plot_difference(error, grains, stress, threshold):
    grid_1 = pv.UniformGrid()
    grid_1.dimensions = np.array(error.shape) + 1
    grid_1.spacing = (1, 1, 1)
    grid_1.cell_data["error"] = error.flatten(order="F")
    grid_2 = pv.UniformGrid()
    grid_2.dimensions = np.array(grains.shape) + 1
    grid_2.spacing = (1, 1, 1)
    grid_2.cell_data["grain"] = grains.flatten(order="F")
    grid_3 = pv.UniformGrid()
    grid_3.dimensions = np.array(stress.shape) + 1
    grid_3.spacing = (1, 1, 1)
    grid_3.cell_data["stress"] = stress.flatten(order="F")
    colormap_error = get_colormap(grid_1, threshold)
    p = pv.Plotter(notebook=False, shape=(3, 1))
    sargs_grain = dict(height=0.75, vertical=True, position_x=0.1, position_y=0.05, n_labels=0)
    sargs_stress = dict(height=0.75, vertical=True, position_x=0.1, position_y=0.05)
    sargs_error = dict(height=0.75, vertical=True, position_x=0.1, position_y=0.05, n_labels=0)
    def my_plane_func(normal, origin):
        # re-slice all three grids with the same plane so the views stay aligned
        slc_1 = grid_1.slice(normal=normal, origin=origin)
        slc_2 = grid_2.slice(normal=normal, origin=origin)
        slc_3 = grid_3.slice(normal=normal, origin=origin)
        p.subplot(0, 0)
        p.add_mesh(slc_2, name="my_slice_2", cmap='RdBu', annotations=annotations_grain, scalar_bar_args=sargs_grain)
        p.subplot(2, 0)
        p.add_mesh(slc_1, name="my_slice_1", cmap=colormap_error, annotations=annotations_error, scalar_bar_args=sargs_error)
        p.subplot(1, 0)
        p.add_mesh(slc_3, name="my_slice_3", scalar_bar_args=sargs_stress)
    p.subplot(0, 0)
    annotations_grain = {
        0: 'Ferrite',
        grains.max(): 'Martensite',
    }
    annotations_error = {
        0.15: '15%'
    }
    p.add_title('Grains', font_size=10)
    p.add_mesh(grid_2, opacity=0, cmap='RdBu', annotations=annotations_grain, scalar_bar_args=sargs_grain)
    p.add_plane_widget(my_plane_func)
    p.subplot(2, 0)
    p.add_title('Error', font_size=10)
    p.add_mesh(grid_1, scalars="error", opacity=0, scalar_bar_args=sargs_error)
    p.add_plane_widget(my_plane_func)
    p.subplot(1, 0)
    p.add_title('Stress', font_size=10)
    p.add_mesh(grid_3, scalars="stress", annotations=annotations_error, opacity=0)
    p.add_plane_widget(my_plane_func)
    p.link_views()
    p.show()
```
%% Cell type:code id: tags:

``` python
path_to_UNET = 'E:/Data/damask3'
UNet = UNet9
Training_data_32 = torch.load(f'{path_to_UNET}/UNet/Input/TD_norm_32_phase.pt')
normalization_32 = np.load(f'{path_to_UNET}/UNet/Input/Norm_min_max_32_phase.npy', allow_pickle=True)
grain_data_32 = torch.load(f'{path_to_UNET}/UNet/Input/TD_norm_32_angles.pt')
model_9 = UNet.UNet()
model_9.load_state_dict(torch.load(f'{path_to_UNET}/UNet/output/V9_diffLR/Unet_dict_V9_3.pth', map_location=torch.device('cpu')))
device_9 = UNet.get_default_device()
model_9 = UNet.to_device(model_9.double(), device_9)
#sample_index = np.random.randint(low=0, high=len(Training_data_32))
sample_index = 1637
print(f'sample number: {sample_index}')
predict_stress(sample_index, normalization=normalization_32, model=model_9, dataset=Training_data_32, grain_data=grain_data_32, UNet=UNet)
```
%% Output

no GPU found
sample number: 1637
Maximum error is : 74.46 %
average error is : 10.06 %
60.49% of voxels have a deviation less than 10.0%
%% Cell type:code id: tags:

``` python
path_to_UNET = 'E:/Data/damask3'
UNet = UNet15
Training_data_64 = torch.load(f'{path_to_UNET}/UNet/Input/TD_norm_64_phase.pt')
grain_data_64 = torch.load(f'{path_to_UNET}/UNet/Input/TD_norm_64_angles.pt')
#history = torch.load(f'{path_to_UNET}/UNet/output/V15/history_V15.pt')
#history_2 = torch.load('E:/Data/damask3/UNet/output/history_test.pt')
normalization_64 = np.load(f'{path_to_UNET}/UNet/Input/Norm_min_max_64_phase.npy', allow_pickle=True)
model_15 = UNet.UNet()
model_15.load_state_dict(torch.load(f'{path_to_UNET}/UNet/output/V15/Unet_dict_V15.pth', map_location=torch.device('cpu')))
device_15 = UNet.get_default_device()
model_15 = UNet.to_device(model_15.double(), device_15)
sample_index = np.random.randint(low=0, high=len(Training_data_64))
print(f'sample number: {sample_index}')
predict_stress(sample_index, normalization=normalization_64, model=model_15, dataset=Training_data_64, UNet=UNet, grain_data=grain_data_64)
```
%% Cell type:code id: tags:

``` python
sample_index = np.random.randint(low=0, high=len(Training_data_32))
print(f'sample number: {sample_index}')
predict_stress(sample_index, normalization=normalization_32, model=model_9, dataset=Training_data_32, grain_data=grain_data_32, UNet=UNet)
```
%% Output

sample number: 982
Maximum error is : 142.9 %
average error is : 15.91 %
43.13% of voxels have a deviation less than 10.0%
ERROR:root:1: #version 150
%% Cell type:code id: tags:

``` python
sample_index = np.random.randint(low=0, high=len(Training_data_64))
print(f'sample number: {sample_index}')
predict_stress(sample_index, normalization=normalization_64, model=model_15, dataset=Training_data_64, grain_data=grain_data_64, UNet=UNet)
```
%% Output

sample number: 83

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_23188/176243172.py in <module>
      1 sample_index = np.random.randint(low=0, high=len(Training_data_64))
      2 print(f'sample number: {sample_index}')
----> 3 predict_stress(sample_index, normalization = normalization_64, model = model_15, dataset = Training_data_64, grain_data = grain_data_64)

~\AppData\Local\Temp/ipykernel_23188/2786384287.py in predict_stress(image_id, normalization, model, dataset, grain_data, threshold)
     11 xb = UNet15.to_device(input, device_15)
     12 model.eval()
---> 13 prediction = model(xb)
     14 input = input.detach().numpy()
     15 prediction = prediction.detach().numpy()

~\Miniconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

e:\Data\damask3\UNet\UNet_V15.py in forward(self, x)
    149     def forward(self, x):
    150         enc_ftrs = self.encoder(x)
--> 151         out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
    152         #out = self.head(out)
    153         return out

~\Miniconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

e:\Data\damask3\UNet\UNet_V15.py in forward(self, x, encoder_features)
     95         #print(f'size after cropping&cat: {x.size()}')
     96 
---> 97         x = self.dec_blocks[i](x)
     98         #print(f'size after convolution: {x.size()}')
     99         x = self.head(x)

~\Miniconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

e:\Data\damask3\UNet\UNet_V15.py in forward(self, x)
     31         self.batch_norm_2 = nn.BatchNorm3d(out_2_c)
     32     def forward(self, x):
---> 33         x = self.batch_norm_1(self.relu(self.droptout(self.pointwise_1(self.depthwise_1(x)))))
     34         return self.batch_norm_2(self.relu(self.droptout(self.pointwise_2(self.depthwise_2(x)))))
     35 

~\Miniconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~\Miniconda3\lib\site-packages\torch\nn\modules\conv.py in forward(self, input)
    588 
    589     def forward(self, input: Tensor) -> Tensor:
--> 590         return self._conv_forward(input, self.weight, self.bias)
    591 
    592 

~\Miniconda3\lib\site-packages\torch\nn\modules\conv.py in _conv_forward(self, input, weight, bias)
    583             self.groups,
    584         )
--> 585         return F.conv3d(
    586             input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    587         )

RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:76] data. DefaultCPUAllocator: not enough memory: you tried to allocate 14386462720 bytes.
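%% Cell type:markdown id: tags:

The 14 GB allocation above happens while autograd is still recording the forward pass; for pure inference those activation buffers are unnecessary. Wrapping the model call in `torch.no_grad()` is the usual remedy and may be enough to fit the 64³ volume in CPU memory, though that depends on the machine. A sketch of the guarded call (the same effect could be had by guarding `prediction = model(xb)` inside `predict_stress`):

%% Cell type:code id: tags:

``` python
# same prediction as above, but without autograd bookkeeping
with torch.no_grad():
    prediction = model_15(xb)   # xb: the batched input volume, as in predict_stress
```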
%% Cell type:code id: tags:

``` python
mean_error, max_error, correct_per = dataset_evaluation(normalization=normalization, model=model, dataset=Training_data, threshold=0.1)
```
%% Cell type:code id: tags:

``` python
def dataset_evaluation(normalization, model, dataset, threshold = 0.05):
    model.eval()
    mean_error = np.empty(len(dataset))
    max_error = np.empty(len(dataset))
    correct_per = np.empty(len(dataset))  # percentage of voxels predicted correctly, according to the threshold
    for index in range(len(dataset)):
        input, output = dataset[index]
        input = copy.copy(input)
        output = copy.copy(output)
        input = torch.unsqueeze(input, 0)
        output = torch.unsqueeze(output, 0)
        xb = UNet.to_device(input, device)
        prediction = model(xb)
        input = input.detach().numpy()
        prediction = prediction.detach().numpy()
        output = output.detach().numpy()
        prediction = rescale(prediction, normalization)
        output = rescale(output, normalization)
        error = (abs(output - prediction) / output)
        right_predic = (error < threshold).sum()
        mean_error[index] = error.mean() * 100.
        max_error[index] = error.max() * 100.
        correct_per[index] = right_predic / error.size * 100.  # was a raw count; normalized to a percentage as in predict_stress
    return mean_error, max_error, correct_per
```
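%% Cell type:markdown id: tags:

The changelog numbers at the top of this commit ("mean error over whole set", "max error average ... and maximum") read as aggregates of these per-sample arrays; a minimal reproduction under that assumption:

%% Cell type:code id: tags:

``` python
# aggregate the per-sample statistics the way the changelog reports them
print(f'mean error over whole set: {mean_error.mean()}')
print(f'max error average: {max_error.mean()} and maximum {max_error.max()}')
print(f'voxels within threshold on average: {correct_per.mean():.4} %')
```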
......