Skip to content
Snippets Groups Projects
Commit 722c66c7 authored by Ali Darijani's avatar Ali Darijani
Browse files

commit

parent 31abb105
No related branches found
No related tags found
No related merge requests found
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import math
import time
import pprint
# import pathlib
import datetime
from datetime import timedelta
start = time.time()
cuda_opt = True
if torch.cuda.is_available() & cuda_opt:
DEVICE = "cuda"
else:
DEVICE = "cpu"
image_path = '/mnt/Data/darijani/AdaLISTA/natural_images/stanford.jpg'
X = (torch.from_numpy(io.imread(image_path)).to(torch.cdouble)).to(DEVICE)
# print(image)
n1, n2 = X.shape[:2]
# print(height,width)
# Make masks and linear sampling operators
# # Each mask has iid entries following the octanary pattern in which the entries are
# # distributed as b1 x b2 where
# # b1 is uniform over {1, -1, i, -i} (phase)
# # b2 is equal to 1/sqrt(2) with prob. 4/5 and sqrt(3) with prob. 1/5 (magnitude)
# # Number of masks
L = 21
# # Storage for L masks, each of dim n1 x n2
Masks_1 = torch.zeros(n1,n2,L,device=DEVICE,dtype=torch.cdouble)
# # for ll in range(L):
Masks_1 = torch.from_numpy(np.random.choice([1,-1,1j,-1j], (n1,n2,L), p=[1/4, 1/4, 1/4, 1/4])).to(DEVICE)
Masks_2 = torch.from_numpy(np.random.choice([1/math.sqrt(2),math.sqrt(3)], (n1,n2,L), p=[4/5, 1/5])).to(DEVICE)
Masks = torch.mul(Masks_1,Masks_2)
print(f"Masks_1 is on: {Masks_1.device}")
print(f"Masks_2 is on: {Masks_2.device}")
print(f"Masks is on: {Masks.device}")
# # Make Linear Operators
def A(I):
conj = torch.conj_physical(Masks)
repeat = torch.tile(I, (1, L))
reshape = torch.reshape(repeat,[I.size(0), I.size(1),L])
hadamard = torch.mul(conj,reshape)
result = torch.fft.fft2(hadamard)
return result
def At(Y):
inverse_fft =torch.fft.ifft2(Y)
mul = torch.mul(inverse_fft,Masks)
sum = torch.sum(mul,2)
normalized = sum * Y.size(0) * Y.size(1)
result = normalized
return result
# # Prepare structure to save intermediate results
# ttimes = [10,300]; % Iterations at which we will save info
# ntimes = length(ttimes)+1; % +1 because we will save info after the initialization
# # Xhats = cell(1,ntimes);
# # for mm = 1:ntimes, Xhats{mm} = zeros(size(X)); end
# # Times = zeros(3,ntimes);
ttimes = [10,30] # Iterations at which we will save info
ntimes = len(ttimes)+1; # +1 because we will save info after the initialization
Xhats = torch.zeros([ntimes,n1,n2,3],dtype=torch.cdouble,device=DEVICE)
Times = torch.zeros(3,ntimes)
# # Wirtinger flow
npower_iter = 50 # Number of power iterations
T = max(ttimes) # Max number of iterations
tau0 = 330 # Time constant for step size
def mu(t):
return min(1-math.exp(-t/tau0), 0.4) # Schedule for step size
for rgb in range(0,X.shape[-1]):
# print(rgb)
print(f'Color band {rgb}')
x = torch.squeeze(X[:,:,rgb]); # Image x is n1 x n2
Y = torch.abs(A(x))**2; # Measured data
z0 = torch.randn(n1,n2,device=DEVICE)
z0 = z0/torch.linalg.norm(z0,'fro') # Initial guess
# start_init = time.time()
for tt in range(0,npower_iter):
z0 = At(A(z0))
z0 = z0/torch.norm(z0,'fro')
# end_init = time.time()
# elapsed_time_init = end_init - start_init
# print("elapsed time is:")
# print(str(timedelta(seconds=elapsed_time_init)),"(HH:MM:SS)")
normest = torch.sqrt(torch.sum(Y)/torch.numel(Y)) # Estimate norm to scale eigenvector
z = normest * z0 # Apply scaling
# print(x)
# print(z)
# print(z.shape)
# Xhats[:,:,rgb] = torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z
# Xhats[:,:,rgb] = torch.exp(-1j*torch.angle() * z) # Initial guess after global phase adjustment
# Loop
print(f'Done with initialization, starting loop\n')
for t in range(1,T):
print(t)
Bz = A(z)
C = torch.abs(Bz)**2-Y * Bz
grad = At(C)/torch.numel(C) # Wirtinger gradient
z = z - mu(t)/normest**2 * grad # Gradient update
if t == T-2:
Xhats[0,:,:,rgb] = torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z
# print(Xhats[0,:,:,rgb].shape)
# print((torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z).shape)
# print(Xhats.shape)
# print(ttimes)
# red_image = image[:,:,0]
# daten = np.tile(red_image, (1, L))
# print(daten.shape)
# print(red_image.repeat(1,L).shape)
# print(red_image.shape)
# print(mu(4))
# At(A(torch.from_numpy(image[:,:,0]).to(DEVICE)))
# np_arr = X.detach().cpu().numpy()
# image = np_arr.astype(int)
# # print(type(X))
# # print(X)
# plt.imshow(image)
# plt.axis('off')
# plt.show()
# io.imsave("/mnt/Data/darijani/AdaLISTA/natural_images/reconstructed.png",image)
end = time.time()
elapsed_time = end - start
print("total elapsed time is:")
print(str(timedelta(seconds=elapsed_time)),"(HH:MM:SS)")
\ No newline at end of file
natural_images/saved.png

683 KiB

%% Cell type:code id: tags:
``` python
import skimage.io
import torch
import torch.fft as fft
from skimage import color
# Implementation of the Wirtinger Flow (WF) algorithm
def wirtinger_flow(X, L=21, npower_iter=50, T=300, tau0=330):
# X: n1 x n2 x 3 tensor representing RGB image
n1, n2, _ = X.shape
# Make masks and linear sampling operators
Masks = torch.zeros(n1, n2, L, dtype=torch.cfloat)
for ll in range(L):
phase = torch.choice([1j, -1j, 1, -1], (n1, n2))
magnitude = torch.where(torch.rand(n1, n2) <= 0.2, torch.sqrt(torch.tensor(3.0)), torch.tensor(1 / torch.sqrt(2)))
Masks[:, :, ll] = phase * magnitude
A = lambda I: fft.fft2(torch.conj(Masks) * I.view(n1, n2, 1).repeat(1, 1, L))
At = lambda Y: torch.sum(Masks * fft.ifft2(Y), dim=2) * n1 * n2
# Prepare structure to save intermediate results
ttimes = [150, 300]
ntimes = len(ttimes) + 1
Xhats = [torch.zeros(n1, n2, 3) for _ in range(ntimes)]
Times = torch.zeros(3, ntimes)
for rgb in range(3):
print(f'Color band {rgb+1}')
x = X[:, :, rgb]
Y = torch.abs(A(x))**2
# Initialization
z0 = torch.randn(n1, n2, dtype=torch.cfloat)
z0 = z0 / torch.norm(z0, 'fro')
for tt in range(npower_iter):
z0 = At(A(z0))
z0 = z0 / torch.norm(z0, 'fro')
Times[rgb, 0] = 0.0
normest = torch.sqrt(torch.sum(Y) / Y.numel())
z = normest * z0
Xhats[0][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.T @ z))) * z
# Loop
print('Done with initialization, starting loop')
for t in range(1, T + 1):
Bz = A(z)
C = (torch.abs(Bz)**2 - Y) * Bz
grad = At(C) / C.numel() # Wirtinger gradient
z = z - mu(t, tau0) / normest**2 * grad # Gradient update
if t in ttimes:
ind = ttimes.index(t) + 1
Xhats[ind][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.T @ z))) * z
Times[rgb, ind] = 0.0
print('All done!')
return Xhats, Times
def mu(t, tau0):
return min(1 - torch.exp(-t / tau0), 0.4)
if __name__ == "__main__":
# Read the RGB image (you need to provide the path to the image)
image_path = '/mnt/Data/darijani/AdaLISTA/natural_images/galaxy.jpg'
image = skimage.io.imread(image_path)
# % Below X is n1 x n2 x 3; i.e. we have three n1 x n2 images, one for each of the 3 color channels
n1, n2 = image.shape[:2]
# namestr = 'stanford'
# stanstr = 'jpg'
X = torch.tensor(color.rgb2gray(image), dtype=torch.float32)
# Apply the Wirtinger Flow algorithm
Xhats, Times = wirtinger_flow(X)
# Show some results
# for tt in range(len(Xhats)):
# for rgb in range(3):
# imshow(Xhats[tt][:, :, rgb].abs(), cmap='gray')
```
%% Output
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 74
71 X = torch.tensor(color.rgb2gray(image), dtype=torch.float32)
73 # Apply the Wirtinger Flow algorithm
---> 74 Xhats, Times = wirtinger_flow(X)
76 # Show some results
77 # for tt in range(len(Xhats)):
78 # for rgb in range(3):
79 # imshow(Xhats[tt][:, :, rgb].abs(), cmap='gray')
Cell In[6], line 9, in wirtinger_flow(X, L, npower_iter, T, tau0)
7 def wirtinger_flow(X, L=21, npower_iter=50, T=300, tau0=330):
8 # X: n1 x n2 x 3 tensor representing RGB image
----> 9 n1, n2, _ = X.shape
11 # Make masks and linear sampling operators
12 Masks = torch.zeros(n1, n2, L, dtype=torch.cfloat)
ValueError: not enough values to unpack (expected 3, got 2)
%% Cell type:code id: tags:
``` python
```
import torch
import torch.fft as fft
import math
# Function to create random masks
def create_masks(n1, n2, L):
masks = torch.randint(0, 4, (n1, n2, L))
masks = torch.where(masks == 0, torch.tensor(1j), masks)
masks = torch.where(masks == 1, torch.tensor(-1j), masks)
masks = torch.where(masks == 2, torch.tensor(1), masks)
masks = torch.where(masks == 3, torch.tensor(-1), masks)
temp = torch.rand(n1, n2, L)
masks = masks * torch.where(temp <= 0.2, torch.tensor(math.sqrt(3)), torch.tensor(1 / math.sqrt(2)))
return masks
# Function to compute the linear operator A
def linear_operator_A(I, masks):
masks_conj = torch.conj(masks)
repeated_I = I.unsqueeze(2).repeat(1, 1, masks.shape[2])
F_I = fft.fft2(repeated_I)
A_I = F_I * masks_conj
return A_I
# Function to compute the adjoint linear operator A^T
def linear_operator_At(Y, masks):
masks_Y = masks * fft.ifft2(Y)
return masks_Y.sum(dim=2) * Y.shape[1] * Y.shape[2]
# Wirtinger flow implementation
def wirtinger_flow(X, L, npower_iter, ttimes, tau0):
n1, n2, _ = X.shape
times = [150, 300]
ntimes = len(times) + 1
Xhats = [torch.zeros_like(X,dtype=torch.cdouble) for _ in range(ntimes)]
for mm in range(ntimes):
Xhats[mm] = X.clone()
Times = torch.zeros(3, ntimes)
for rgb in range(3):
print(f'Color band {rgb + 1}')
x = X[:, :, rgb] # Image x is n1 x n2
Y = torch.abs(linear_operator_A(x, masks)) ** 2 # Measured data
# Initialization
z0 = torch.randn(n1, n2).to(x.device)
z0 = z0 / torch.norm(z0, 'fro')
for tt in range(npower_iter):
z0 = linear_operator_At(linear_operator_A(z0, masks), masks)
z0 = z0 / torch.norm(z0, 'fro')
Times[rgb, 0] = 0
normest = torch.sqrt(Y.sum() / Y.numel()) # Estimate norm to scale eigenvector
z = normest * z0 # Apply scaling
Xhats[0][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.t() @ z))) * z # Initial guess after global phase adjustment
# Loop
print('Done with initialization, starting loop')
for t in range(T):
Bz = linear_operator_A(z, masks)
C = (torch.abs(Bz) ** 2 - Y) * Bz
grad = linear_operator_At(C / C.numel(), masks) # Wirtinger gradient
z = z - mu(t) / (normest ** 2) * grad # Gradient update
ind = ttimes.index(t) if t in ttimes else -1
if ind >= 0:
Xhats[ind + 1][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.t() @ z))) * z
Times[rgb, ind + 1] = 0
print('All done!')
return Xhats, Times
# Example usage
if __name__ == '__main__':
import numpy as np
from PIL import Image
# Read Image
namestr = 'stanford'
stanstr = 'jpg'
img_path = f'{namestr}.{stanstr}'
img = Image.open(img_path)
X = torch.tensor(np.array(img)).float() / 255.0
n1, n2 = X.shape[:2]
# Define parameters
L = 21
npower_iter = 50
ttimes = [150, 300]
tau0 = 330
# Create masks
masks = create_masks(n1, n2, L)
# Perform Wirtinger Flow
Xhats, Times = wirtinger_flow(X, L, npower_iter, ttimes, tau0)
# Show some results
iter = [0] + ttimes
Relerrs = []
for mm in range(len(Xhats)):
mean_time = torch.mean(Times[:, mm])
Relerr = torch.norm(Xhats[mm].flatten() - X.flatten()) / torch.norm(X.flatten())
Relerrs.append(Relerr.item())
print(f"Mean running times after {iter[mm]} iterations: {mean_time:.1f}")
print(f"Relative error after {iter[mm]} iterations: {Relerr:.5f}\n")
# Display the resulting images
for tt in range(len(Xhats)):
img = Image.fromarray(np.uint8(Xhats[tt].numpy() * 255))
img.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment