diff --git a/natural_images/main.py b/natural_images/main.py deleted file mode 100644 index 5d64ed31d3bf394afb16164e50f2c97105066d5b..0000000000000000000000000000000000000000 --- a/natural_images/main.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import numpy as np -import matplotlib.pyplot as plt -from skimage import io -import math -import time -import pprint -# import pathlib -import datetime -from datetime import timedelta -start = time.time() -cuda_opt = True -if torch.cuda.is_available() & cuda_opt: - DEVICE = "cuda" -else: - DEVICE = "cpu" - -image_path = '/mnt/Data/darijani/AdaLISTA/natural_images/stanford.jpg' -X = (torch.from_numpy(io.imread(image_path)).to(torch.cdouble)).to(DEVICE) - -# print(image) -n1, n2 = X.shape[:2] -# print(height,width) -# Make masks and linear sampling operators - -# # Each mask has iid entries following the octanary pattern in which the entries are -# # distributed as b1 x b2 where -# # b1 is uniform over {1, -1, i, -i} (phase) -# # b2 is equal to 1/sqrt(2) with prob. 4/5 and sqrt(3) with prob. 1/5 (magnitude) - -# # Number of masks -L = 21 -# # Storage for L masks, each of dim n1 x n2 -Masks_1 = torch.zeros(n1,n2,L,device=DEVICE,dtype=torch.cdouble) -# # for ll in range(L): -Masks_1 = torch.from_numpy(np.random.choice([1,-1,1j,-1j], (n1,n2,L), p=[1/4, 1/4, 1/4, 1/4])).to(DEVICE) -Masks_2 = torch.from_numpy(np.random.choice([1/math.sqrt(2),math.sqrt(3)], (n1,n2,L), p=[4/5, 1/5])).to(DEVICE) -Masks = torch.mul(Masks_1,Masks_2) -print(f"Masks_1 is on: {Masks_1.device}") -print(f"Masks_2 is on: {Masks_2.device}") -print(f"Masks is on: {Masks.device}") - - - -# # Make Linear Operators -def A(I): - conj = torch.conj_physical(Masks) - repeat = torch.tile(I, (1, L)) - reshape = torch.reshape(repeat,[I.size(0), I.size(1),L]) - hadamard = torch.mul(conj,reshape) - result = torch.fft.fft2(hadamard) - return result - -def At(Y): - inverse_fft =torch.fft.ifft2(Y) - mul = torch.mul(inverse_fft,Masks) - sum = torch.sum(mul,2) - normalized = sum * Y.size(0) * Y.size(1) - result = normalized - return result - -# # Prepare structure to save intermediate results - -# ttimes = [10,300]; % Iterations at which we will save info -# ntimes = length(ttimes)+1; % +1 because we will save info after the initialization -# # Xhats = cell(1,ntimes); -# # for mm = 1:ntimes, Xhats{mm} = zeros(size(X)); end -# # Times = zeros(3,ntimes); -ttimes = [10,30] # Iterations at which we will save info -ntimes = len(ttimes)+1; # +1 because we will save info after the initialization - -Xhats = torch.zeros([ntimes,n1,n2,3],dtype=torch.cdouble,device=DEVICE) -Times = torch.zeros(3,ntimes) - -# # Wirtinger flow - -npower_iter = 50 # Number of power iterations -T = max(ttimes) # Max number of iterations -tau0 = 330 # Time constant for step size -def mu(t): - return min(1-math.exp(-t/tau0), 0.4) # Schedule for step size - - - -for rgb in range(0,X.shape[-1]): - # print(rgb) - print(f'Color band {rgb}') - x = torch.squeeze(X[:,:,rgb]); # Image x is n1 x n2 - Y = torch.abs(A(x))**2; # Measured data - z0 = torch.randn(n1,n2,device=DEVICE) - z0 = z0/torch.linalg.norm(z0,'fro') # Initial guess - # start_init = time.time() - for tt in range(0,npower_iter): - z0 = At(A(z0)) - z0 = z0/torch.norm(z0,'fro') - # end_init = time.time() - # elapsed_time_init = end_init - start_init - # print("elapsed time is:") - # print(str(timedelta(seconds=elapsed_time_init)),"(HH:MM:SS)") - - normest = torch.sqrt(torch.sum(Y)/torch.numel(Y)) # Estimate norm to scale eigenvector - z = normest * z0 # Apply scaling - # print(x) - # print(z) - - # print(z.shape) - # Xhats[:,:,rgb] = torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z - # Xhats[:,:,rgb] = torch.exp(-1j*torch.angle() * z) # Initial guess after global phase adjustment - # Loop - print(f'Done with initialization, starting loop\n') - - for t in range(1,T): - print(t) - Bz = A(z) - C = torch.abs(Bz)**2-Y * Bz - grad = At(C)/torch.numel(C) # Wirtinger gradient - z = z - mu(t)/normest**2 * grad # Gradient update - if t == T-2: - Xhats[0,:,:,rgb] = torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z - # print(Xhats[0,:,:,rgb].shape) - # print((torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z).shape) - - - - - - - - -# print(Xhats.shape) -# print(ttimes) -# red_image = image[:,:,0] -# daten = np.tile(red_image, (1, L)) -# print(daten.shape) -# print(red_image.repeat(1,L).shape) -# print(red_image.shape) -# print(mu(4)) -# At(A(torch.from_numpy(image[:,:,0]).to(DEVICE))) -# np_arr = X.detach().cpu().numpy() -# image = np_arr.astype(int) -# # print(type(X)) -# # print(X) -# plt.imshow(image) -# plt.axis('off') -# plt.show() - - -# io.imsave("/mnt/Data/darijani/AdaLISTA/natural_images/reconstructed.png",image) - - -end = time.time() -elapsed_time = end - start -print("total elapsed time is:") -print(str(timedelta(seconds=elapsed_time)),"(HH:MM:SS)") \ No newline at end of file diff --git a/natural_images/saved.png b/natural_images/saved.png deleted file mode 100644 index a6eeef19ac719d1836b7b13d97b72ec0448cdc14..0000000000000000000000000000000000000000 Binary files a/natural_images/saved.png and /dev/null differ diff --git a/natural_images/test.ipynb b/natural_images/test.ipynb deleted file mode 100644 index e13c6aec145977a06e85cae87f20d1f5bb131743..0000000000000000000000000000000000000000 --- a/natural_images/test.ipynb +++ /dev/null @@ -1,133 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "ename": "ValueError", - "evalue": "not enough values to unpack (expected 3, got 2)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[6], line 74\u001b[0m\n\u001b[1;32m 71\u001b[0m X \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor(color\u001b[39m.\u001b[39mrgb2gray(image), dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mfloat32)\n\u001b[1;32m 73\u001b[0m \u001b[39m# Apply the Wirtinger Flow algorithm\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m Xhats, Times \u001b[39m=\u001b[39m wirtinger_flow(X)\n\u001b[1;32m 76\u001b[0m \u001b[39m# Show some results\u001b[39;00m\n\u001b[1;32m 77\u001b[0m \u001b[39m# for tt in range(len(Xhats)):\u001b[39;00m\n\u001b[1;32m 78\u001b[0m \u001b[39m# for rgb in range(3):\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[39m# imshow(Xhats[tt][:, :, rgb].abs(), cmap='gray')\u001b[39;00m\n", - "Cell \u001b[0;32mIn[6], line 9\u001b[0m, in \u001b[0;36mwirtinger_flow\u001b[0;34m(X, L, npower_iter, T, tau0)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mwirtinger_flow\u001b[39m(X, L\u001b[39m=\u001b[39m\u001b[39m21\u001b[39m, npower_iter\u001b[39m=\u001b[39m\u001b[39m50\u001b[39m, T\u001b[39m=\u001b[39m\u001b[39m300\u001b[39m, tau0\u001b[39m=\u001b[39m\u001b[39m330\u001b[39m):\n\u001b[1;32m 8\u001b[0m \u001b[39m# X: n1 x n2 x 3 tensor representing RGB image\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m n1, n2, _ \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m 11\u001b[0m \u001b[39m# Make masks and linear sampling operators\u001b[39;00m\n\u001b[1;32m 12\u001b[0m Masks \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mzeros(n1, n2, L, dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mcfloat)\n", - "\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 3, got 2)" - ] - } - ], - "source": [ - "import skimage.io\n", - "import torch\n", - "import torch.fft as fft\n", - "from skimage import color\n", - "\n", - "# Implementation of the Wirtinger Flow (WF) algorithm\n", - "def wirtinger_flow(X, L=21, npower_iter=50, T=300, tau0=330):\n", - " # X: n1 x n2 x 3 tensor representing RGB image\n", - " n1, n2, _ = X.shape\n", - "\n", - " # Make masks and linear sampling operators\n", - " Masks = torch.zeros(n1, n2, L, dtype=torch.cfloat)\n", - " for ll in range(L):\n", - " phase = torch.choice([1j, -1j, 1, -1], (n1, n2))\n", - " magnitude = torch.where(torch.rand(n1, n2) <= 0.2, torch.sqrt(torch.tensor(3.0)), torch.tensor(1 / torch.sqrt(2)))\n", - " Masks[:, :, ll] = phase * magnitude\n", - "\n", - " A = lambda I: fft.fft2(torch.conj(Masks) * I.view(n1, n2, 1).repeat(1, 1, L))\n", - " At = lambda Y: torch.sum(Masks * fft.ifft2(Y), dim=2) * n1 * n2\n", - "\n", - " # Prepare structure to save intermediate results\n", - " ttimes = [150, 300]\n", - " ntimes = len(ttimes) + 1\n", - " Xhats = [torch.zeros(n1, n2, 3) for _ in range(ntimes)]\n", - " Times = torch.zeros(3, ntimes)\n", - "\n", - " for rgb in range(3):\n", - " print(f'Color band {rgb+1}')\n", - " x = X[:, :, rgb]\n", - " Y = torch.abs(A(x))**2\n", - "\n", - " # Initialization\n", - " z0 = torch.randn(n1, n2, dtype=torch.cfloat)\n", - " z0 = z0 / torch.norm(z0, 'fro')\n", - " for tt in range(npower_iter):\n", - " z0 = At(A(z0))\n", - " z0 = z0 / torch.norm(z0, 'fro')\n", - " Times[rgb, 0] = 0.0\n", - "\n", - " normest = torch.sqrt(torch.sum(Y) / Y.numel())\n", - " z = normest * z0\n", - " Xhats[0][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.T @ z))) * z\n", - "\n", - " # Loop\n", - " print('Done with initialization, starting loop')\n", - " for t in range(1, T + 1):\n", - " Bz = A(z)\n", - " C = (torch.abs(Bz)**2 - Y) * Bz\n", - " grad = At(C) / C.numel() # Wirtinger gradient\n", - " z = z - mu(t, tau0) / normest**2 * grad # Gradient update\n", - "\n", - " if t in ttimes:\n", - " ind = ttimes.index(t) + 1\n", - " Xhats[ind][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.T @ z))) * z\n", - " Times[rgb, ind] = 0.0\n", - "\n", - " print('All done!')\n", - " return Xhats, Times\n", - "\n", - "def mu(t, tau0):\n", - " return min(1 - torch.exp(-t / tau0), 0.4)\n", - "\n", - "if __name__ == \"__main__\":\n", - " # Read the RGB image (you need to provide the path to the image)\n", - " image_path = '/mnt/Data/darijani/AdaLISTA/natural_images/galaxy.jpg'\n", - " image = skimage.io.imread(image_path)\n", - "# % Below X is n1 x n2 x 3; i.e. we have three n1 x n2 images, one for each of the 3 color channels \n", - " n1, n2 = image.shape[:2]\n", - " # namestr = 'stanford'\n", - " # stanstr = 'jpg'\n", - " X = torch.tensor(color.rgb2gray(image), dtype=torch.float32)\n", - "\n", - " # Apply the Wirtinger Flow algorithm\n", - " Xhats, Times = wirtinger_flow(X)\n", - "\n", - " # Show some results\n", - " # for tt in range(len(Xhats)):\n", - " # for rgb in range(3):\n", - " # imshow(Xhats[tt][:, :, rgb].abs(), cmap='gray')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "torch", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/natural_images/test.py b/natural_images/test.py deleted file mode 100644 index 5c76730296877ef8dac1f2ac81c3ef817aac7aba..0000000000000000000000000000000000000000 --- a/natural_images/test.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.fft as fft -import math - -# Function to create random masks -def create_masks(n1, n2, L): - masks = torch.randint(0, 4, (n1, n2, L)) - masks = torch.where(masks == 0, torch.tensor(1j), masks) - masks = torch.where(masks == 1, torch.tensor(-1j), masks) - masks = torch.where(masks == 2, torch.tensor(1), masks) - masks = torch.where(masks == 3, torch.tensor(-1), masks) - temp = torch.rand(n1, n2, L) - masks = masks * torch.where(temp <= 0.2, torch.tensor(math.sqrt(3)), torch.tensor(1 / math.sqrt(2))) - return masks - -# Function to compute the linear operator A -def linear_operator_A(I, masks): - masks_conj = torch.conj(masks) - repeated_I = I.unsqueeze(2).repeat(1, 1, masks.shape[2]) - F_I = fft.fft2(repeated_I) - A_I = F_I * masks_conj - return A_I - -# Function to compute the adjoint linear operator A^T -def linear_operator_At(Y, masks): - masks_Y = masks * fft.ifft2(Y) - return masks_Y.sum(dim=2) * Y.shape[1] * Y.shape[2] - -# Wirtinger flow implementation -def wirtinger_flow(X, L, npower_iter, ttimes, tau0): - n1, n2, _ = X.shape - times = [150, 300] - ntimes = len(times) + 1 - Xhats = [torch.zeros_like(X,dtype=torch.cdouble) for _ in range(ntimes)] - for mm in range(ntimes): - Xhats[mm] = X.clone() - - Times = torch.zeros(3, ntimes) - - for rgb in range(3): - print(f'Color band {rgb + 1}') - x = X[:, :, rgb] # Image x is n1 x n2 - Y = torch.abs(linear_operator_A(x, masks)) ** 2 # Measured data - - # Initialization - z0 = torch.randn(n1, n2).to(x.device) - z0 = z0 / torch.norm(z0, 'fro') - for tt in range(npower_iter): - z0 = linear_operator_At(linear_operator_A(z0, masks), masks) - z0 = z0 / torch.norm(z0, 'fro') - Times[rgb, 0] = 0 - - normest = torch.sqrt(Y.sum() / Y.numel()) # Estimate norm to scale eigenvector - z = normest * z0 # Apply scaling - Xhats[0][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.t() @ z))) * z # Initial guess after global phase adjustment - - # Loop - print('Done with initialization, starting loop') - for t in range(T): - Bz = linear_operator_A(z, masks) - C = (torch.abs(Bz) ** 2 - Y) * Bz - grad = linear_operator_At(C / C.numel(), masks) # Wirtinger gradient - z = z - mu(t) / (normest ** 2) * grad # Gradient update - - ind = ttimes.index(t) if t in ttimes else -1 - if ind >= 0: - Xhats[ind + 1][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.t() @ z))) * z - Times[rgb, ind + 1] = 0 - - print('All done!') - return Xhats, Times - -# Example usage -if __name__ == '__main__': - import numpy as np - from PIL import Image - - # Read Image - namestr = 'stanford' - stanstr = 'jpg' - img_path = f'{namestr}.{stanstr}' - img = Image.open(img_path) - X = torch.tensor(np.array(img)).float() / 255.0 - n1, n2 = X.shape[:2] - - # Define parameters - L = 21 - npower_iter = 50 - ttimes = [150, 300] - tau0 = 330 - - # Create masks - masks = create_masks(n1, n2, L) - - # Perform Wirtinger Flow - Xhats, Times = wirtinger_flow(X, L, npower_iter, ttimes, tau0) - - # Show some results - iter = [0] + ttimes - Relerrs = [] - for mm in range(len(Xhats)): - mean_time = torch.mean(Times[:, mm]) - Relerr = torch.norm(Xhats[mm].flatten() - X.flatten()) / torch.norm(X.flatten()) - Relerrs.append(Relerr.item()) - print(f"Mean running times after {iter[mm]} iterations: {mean_time:.1f}") - print(f"Relative error after {iter[mm]} iterations: {Relerr:.5f}\n") - - # Display the resulting images - for tt in range(len(Xhats)): - img = Image.fromarray(np.uint8(Xhats[tt].numpy() * 255)) - img.show()