diff --git a/natural_images/main.py b/natural_images/main.py
deleted file mode 100644
index 5d64ed31d3bf394afb16164e50f2c97105066d5b..0000000000000000000000000000000000000000
--- a/natural_images/main.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import torch
-import numpy as np
-import matplotlib.pyplot as plt
-from skimage import io
-import math
-import time
-import pprint
-# import pathlib
-import datetime
-from datetime import timedelta
-start = time.time()
-cuda_opt = True
-if torch.cuda.is_available() & cuda_opt:
-    DEVICE = "cuda"
-else:
-    DEVICE = "cpu" 
-
-image_path = '/mnt/Data/darijani/AdaLISTA/natural_images/stanford.jpg'
-X = (torch.from_numpy(io.imread(image_path)).to(torch.cdouble)).to(DEVICE)
-
-# print(image)
-n1, n2 = X.shape[:2]
-# print(height,width)
-# Make masks and linear sampling operators  
-
-# # Each mask has iid entries following the octanary pattern in which the entries are 
-# # distributed as b1 x b2 where 
-# # b1 is uniform over {1, -1, i, -i} (phase) 
-# # b2 is equal to 1/sqrt(2) with prob. 4/5 and sqrt(3) with prob. 1/5 (magnitude)
-
-# # Number of masks
-L = 21              
-# # Storage for L masks, each of dim n1 x n2
-Masks_1 = torch.zeros(n1,n2,L,device=DEVICE,dtype=torch.cdouble) 
-# # for ll in range(L):
-Masks_1 = torch.from_numpy(np.random.choice([1,-1,1j,-1j], (n1,n2,L), p=[1/4, 1/4, 1/4, 1/4])).to(DEVICE)
-Masks_2 = torch.from_numpy(np.random.choice([1/math.sqrt(2),math.sqrt(3)], (n1,n2,L), p=[4/5, 1/5])).to(DEVICE)
-Masks = torch.mul(Masks_1,Masks_2)
-print(f"Masks_1 is on: {Masks_1.device}")
-print(f"Masks_2 is on: {Masks_2.device}")
-print(f"Masks is on: {Masks.device}")
-
-
-
-# # Make Linear Operators
-def A(I):
-    conj = torch.conj_physical(Masks)
-    repeat = torch.tile(I, (1, L))
-    reshape = torch.reshape(repeat,[I.size(0), I.size(1),L])
-    hadamard = torch.mul(conj,reshape)
-    result = torch.fft.fft2(hadamard)
-    return result
-
-def At(Y):
-    inverse_fft =torch.fft.ifft2(Y)
-    mul = torch.mul(inverse_fft,Masks)
-    sum = torch.sum(mul,2)
-    normalized = sum * Y.size(0) * Y.size(1)
-    result = normalized
-    return result
-
-# # Prepare structure to save intermediate results 
-
-# ttimes   = [10,300];        % Iterations at which we will save info 
-# ntimes = length(ttimes)+1;   % +1 because we will save info after the initialization 
-# # Xhats    = cell(1,ntimes);
-# # for mm = 1:ntimes, Xhats{mm} = zeros(size(X));  end
-# # Times    = zeros(3,ntimes);
-ttimes = [10,30]               # Iterations at which we will save info
-ntimes = len(ttimes)+1;       # +1 because we will save info after the initialization
-
-Xhats    = torch.zeros([ntimes,n1,n2,3],dtype=torch.cdouble,device=DEVICE)
-Times    = torch.zeros(3,ntimes)
-
-# # Wirtinger flow  
-
-npower_iter = 50                         # Number of power iterations 
-T = max(ttimes)                          # Max number of iterations
-tau0 = 330                               # Time constant for step size 
-def mu(t):
-    return min(1-math.exp(-t/tau0), 0.4) # Schedule for step size 
-
-
-
-for rgb in range(0,X.shape[-1]):
-    # print(rgb)
-    print(f'Color band {rgb}')
-    x = torch.squeeze(X[:,:,rgb]); # Image x is n1 x n2 
-    Y = torch.abs(A(x))**2;        # Measured data 
-    z0 = torch.randn(n1,n2,device=DEVICE)
-    z0 = z0/torch.linalg.norm(z0,'fro') # Initial guess 
-    # start_init = time.time()
-    for tt in range(0,npower_iter):
-        z0 = At(A(z0)) 
-        z0 = z0/torch.norm(z0,'fro')
-    # end_init = time.time()
-    # elapsed_time_init = end_init - start_init
-    # print("elapsed time is:")
-    # print(str(timedelta(seconds=elapsed_time_init)),"(HH:MM:SS)")
-
-    normest = torch.sqrt(torch.sum(Y)/torch.numel(Y)) # Estimate norm to scale eigenvector  
-    z = normest * z0                   # Apply scaling
-    # print(x)
-    # print(z)
-
-    # print(z.shape)
-    # Xhats[:,:,rgb] = torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z
-    # Xhats[:,:,rgb] = torch.exp(-1j*torch.angle() * z) # Initial guess after global phase adjustment
-    # Loop    
-    print(f'Done with initialization, starting loop\n') 
-
-    for t in range(1,T):
-        print(t)
-        Bz = A(z)
-        C  = torch.abs(Bz)**2-Y * Bz
-        grad = At(C)/torch.numel(C)                   # Wirtinger gradient            
-        z = z - mu(t)/normest**2 * grad        # Gradient update 
-        if t == T-2:
-            Xhats[0,:,:,rgb] = torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z
-            # print(Xhats[0,:,:,rgb].shape)
-            # print((torch.exp(-1j*torch.angle(torch.trace(torch.matmul(x.H,z))))*z).shape)
-
-
-
-
-
-
-
-
-# print(Xhats.shape)
-# print(ttimes)
-# red_image = image[:,:,0]
-# daten = np.tile(red_image, (1, L))
-# print(daten.shape)
-# print(red_image.repeat(1,L).shape)
-# print(red_image.shape)
-# print(mu(4))
-# At(A(torch.from_numpy(image[:,:,0]).to(DEVICE)))
-# np_arr = X.detach().cpu().numpy()
-# image = np_arr.astype(int)
-# # print(type(X))
-# # print(X)
-# plt.imshow(image)
-# plt.axis('off')
-# plt.show()
-
-
-# io.imsave("/mnt/Data/darijani/AdaLISTA/natural_images/reconstructed.png",image)
-
-
-end = time.time()
-elapsed_time = end - start
-print("total elapsed time is:")
-print(str(timedelta(seconds=elapsed_time)),"(HH:MM:SS)")
\ No newline at end of file
diff --git a/natural_images/saved.png b/natural_images/saved.png
deleted file mode 100644
index a6eeef19ac719d1836b7b13d97b72ec0448cdc14..0000000000000000000000000000000000000000
Binary files a/natural_images/saved.png and /dev/null differ
diff --git a/natural_images/test.ipynb b/natural_images/test.ipynb
deleted file mode 100644
index e13c6aec145977a06e85cae87f20d1f5bb131743..0000000000000000000000000000000000000000
--- a/natural_images/test.ipynb
+++ /dev/null
@@ -1,133 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "ValueError",
-     "evalue": "not enough values to unpack (expected 3, got 2)",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[6], line 74\u001b[0m\n\u001b[1;32m     71\u001b[0m X \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor(color\u001b[39m.\u001b[39mrgb2gray(image), dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mfloat32)\n\u001b[1;32m     73\u001b[0m \u001b[39m# Apply the Wirtinger Flow algorithm\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m Xhats, Times \u001b[39m=\u001b[39m wirtinger_flow(X)\n\u001b[1;32m     76\u001b[0m \u001b[39m# Show some results\u001b[39;00m\n\u001b[1;32m     77\u001b[0m \u001b[39m# for tt in range(len(Xhats)):\u001b[39;00m\n\u001b[1;32m     78\u001b[0m     \u001b[39m# for rgb in range(3):\u001b[39;00m\n\u001b[1;32m     79\u001b[0m         \u001b[39m# imshow(Xhats[tt][:, :, rgb].abs(), cmap='gray')\u001b[39;00m\n",
-      "Cell \u001b[0;32mIn[6], line 9\u001b[0m, in \u001b[0;36mwirtinger_flow\u001b[0;34m(X, L, npower_iter, T, tau0)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mwirtinger_flow\u001b[39m(X, L\u001b[39m=\u001b[39m\u001b[39m21\u001b[39m, npower_iter\u001b[39m=\u001b[39m\u001b[39m50\u001b[39m, T\u001b[39m=\u001b[39m\u001b[39m300\u001b[39m, tau0\u001b[39m=\u001b[39m\u001b[39m330\u001b[39m):\n\u001b[1;32m      8\u001b[0m     \u001b[39m# X: n1 x n2 x 3 tensor representing RGB image\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m     n1, n2, _ \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m     11\u001b[0m     \u001b[39m# Make masks and linear sampling operators\u001b[39;00m\n\u001b[1;32m     12\u001b[0m     Masks \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mzeros(n1, n2, L, dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mcfloat)\n",
-      "\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 3, got 2)"
-     ]
-    }
-   ],
-   "source": [
-    "import skimage.io\n",
-    "import torch\n",
-    "import torch.fft as fft\n",
-    "from skimage import color\n",
-    "\n",
-    "# Implementation of the Wirtinger Flow (WF) algorithm\n",
-    "def wirtinger_flow(X, L=21, npower_iter=50, T=300, tau0=330):\n",
-    "    # X: n1 x n2 x 3 tensor representing RGB image\n",
-    "    n1, n2, _ = X.shape\n",
-    "\n",
-    "    # Make masks and linear sampling operators\n",
-    "    Masks = torch.zeros(n1, n2, L, dtype=torch.cfloat)\n",
-    "    for ll in range(L):\n",
-    "        phase = torch.choice([1j, -1j, 1, -1], (n1, n2))\n",
-    "        magnitude = torch.where(torch.rand(n1, n2) <= 0.2, torch.sqrt(torch.tensor(3.0)), torch.tensor(1 / torch.sqrt(2)))\n",
-    "        Masks[:, :, ll] = phase * magnitude\n",
-    "\n",
-    "    A = lambda I: fft.fft2(torch.conj(Masks) * I.view(n1, n2, 1).repeat(1, 1, L))\n",
-    "    At = lambda Y: torch.sum(Masks * fft.ifft2(Y), dim=2) * n1 * n2\n",
-    "\n",
-    "    # Prepare structure to save intermediate results\n",
-    "    ttimes = [150, 300]\n",
-    "    ntimes = len(ttimes) + 1\n",
-    "    Xhats = [torch.zeros(n1, n2, 3) for _ in range(ntimes)]\n",
-    "    Times = torch.zeros(3, ntimes)\n",
-    "\n",
-    "    for rgb in range(3):\n",
-    "        print(f'Color band {rgb+1}')\n",
-    "        x = X[:, :, rgb]\n",
-    "        Y = torch.abs(A(x))**2\n",
-    "\n",
-    "        # Initialization\n",
-    "        z0 = torch.randn(n1, n2, dtype=torch.cfloat)\n",
-    "        z0 = z0 / torch.norm(z0, 'fro')\n",
-    "        for tt in range(npower_iter):\n",
-    "            z0 = At(A(z0))\n",
-    "            z0 = z0 / torch.norm(z0, 'fro')\n",
-    "        Times[rgb, 0] = 0.0\n",
-    "\n",
-    "        normest = torch.sqrt(torch.sum(Y) / Y.numel())\n",
-    "        z = normest * z0\n",
-    "        Xhats[0][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.T @ z))) * z\n",
-    "\n",
-    "        # Loop\n",
-    "        print('Done with initialization, starting loop')\n",
-    "        for t in range(1, T + 1):\n",
-    "            Bz = A(z)\n",
-    "            C = (torch.abs(Bz)**2 - Y) * Bz\n",
-    "            grad = At(C) / C.numel()  # Wirtinger gradient\n",
-    "            z = z - mu(t, tau0) / normest**2 * grad  # Gradient update\n",
-    "\n",
-    "            if t in ttimes:\n",
-    "                ind = ttimes.index(t) + 1\n",
-    "                Xhats[ind][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.T @ z))) * z\n",
-    "                Times[rgb, ind] = 0.0\n",
-    "\n",
-    "    print('All done!')\n",
-    "    return Xhats, Times\n",
-    "\n",
-    "def mu(t, tau0):\n",
-    "    return min(1 - torch.exp(-t / tau0), 0.4)\n",
-    "\n",
-    "if __name__ == \"__main__\":\n",
-    "    # Read the RGB image (you need to provide the path to the image)\n",
-    "    image_path = '/mnt/Data/darijani/AdaLISTA/natural_images/galaxy.jpg'\n",
-    "    image = skimage.io.imread(image_path)\n",
-    "# % Below X is n1 x n2 x 3; i.e. we have three n1 x n2 images, one for each of the 3 color channels  \n",
-    "    n1, n2 = image.shape[:2]\n",
-    "    # namestr = 'stanford'\n",
-    "    # stanstr = 'jpg'\n",
-    "    X = torch.tensor(color.rgb2gray(image), dtype=torch.float32)\n",
-    "\n",
-    "    # Apply the Wirtinger Flow algorithm\n",
-    "    Xhats, Times = wirtinger_flow(X)\n",
-    "\n",
-    "    # Show some results\n",
-    "    # for tt in range(len(Xhats)):\n",
-    "        # for rgb in range(3):\n",
-    "            # imshow(Xhats[tt][:, :, rgb].abs(), cmap='gray')\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "torch",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.9"
-  },
-  "orig_nbformat": 4
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/natural_images/test.py b/natural_images/test.py
deleted file mode 100644
index 5c76730296877ef8dac1f2ac81c3ef817aac7aba..0000000000000000000000000000000000000000
--- a/natural_images/test.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import torch
-import torch.fft as fft
-import math
-
-# Function to create random masks
-def create_masks(n1, n2, L):
-    masks = torch.randint(0, 4, (n1, n2, L))
-    masks = torch.where(masks == 0, torch.tensor(1j), masks)
-    masks = torch.where(masks == 1, torch.tensor(-1j), masks)
-    masks = torch.where(masks == 2, torch.tensor(1), masks)
-    masks = torch.where(masks == 3, torch.tensor(-1), masks)
-    temp = torch.rand(n1, n2, L)
-    masks = masks * torch.where(temp <= 0.2, torch.tensor(math.sqrt(3)), torch.tensor(1 / math.sqrt(2)))
-    return masks
-
-# Function to compute the linear operator A
-def linear_operator_A(I, masks):
-    masks_conj = torch.conj(masks)
-    repeated_I = I.unsqueeze(2).repeat(1, 1, masks.shape[2])
-    F_I = fft.fft2(repeated_I)
-    A_I = F_I * masks_conj
-    return A_I
-
-# Function to compute the adjoint linear operator A^T
-def linear_operator_At(Y, masks):
-    masks_Y = masks * fft.ifft2(Y)
-    return masks_Y.sum(dim=2) * Y.shape[1] * Y.shape[2]
-
-# Wirtinger flow implementation
-def wirtinger_flow(X, L, npower_iter, ttimes, tau0):
-    n1, n2, _ = X.shape
-    times = [150, 300]
-    ntimes = len(times) + 1
-    Xhats = [torch.zeros_like(X,dtype=torch.cdouble) for _ in range(ntimes)]
-    for mm in range(ntimes):
-        Xhats[mm] = X.clone()
-
-    Times = torch.zeros(3, ntimes)
-
-    for rgb in range(3):
-        print(f'Color band {rgb + 1}')
-        x = X[:, :, rgb]  # Image x is n1 x n2
-        Y = torch.abs(linear_operator_A(x, masks)) ** 2  # Measured data
-
-        # Initialization
-        z0 = torch.randn(n1, n2).to(x.device)
-        z0 = z0 / torch.norm(z0, 'fro')
-        for tt in range(npower_iter):
-            z0 = linear_operator_At(linear_operator_A(z0, masks), masks)
-            z0 = z0 / torch.norm(z0, 'fro')
-        Times[rgb, 0] = 0
-
-        normest = torch.sqrt(Y.sum() / Y.numel())  # Estimate norm to scale eigenvector
-        z = normest * z0  # Apply scaling
-        Xhats[0][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.t() @ z))) * z  # Initial guess after global phase adjustment
-
-        # Loop
-        print('Done with initialization, starting loop')
-        for t in range(T):
-            Bz = linear_operator_A(z, masks)
-            C = (torch.abs(Bz) ** 2 - Y) * Bz
-            grad = linear_operator_At(C / C.numel(), masks)  # Wirtinger gradient
-            z = z - mu(t) / (normest ** 2) * grad  # Gradient update
-
-            ind = ttimes.index(t) if t in ttimes else -1
-            if ind >= 0:
-                Xhats[ind + 1][:, :, rgb] = torch.exp(-1j * torch.angle(torch.trace(x.t() @ z))) * z
-                Times[rgb, ind + 1] = 0
-
-    print('All done!')
-    return Xhats, Times
-
-# Example usage
-if __name__ == '__main__':
-    import numpy as np
-    from PIL import Image
-
-    # Read Image
-    namestr = 'stanford'
-    stanstr = 'jpg'
-    img_path = f'{namestr}.{stanstr}'
-    img = Image.open(img_path)
-    X = torch.tensor(np.array(img)).float() / 255.0
-    n1, n2 = X.shape[:2]
-
-    # Define parameters
-    L = 21
-    npower_iter = 50
-    ttimes = [150, 300]
-    tau0 = 330
-
-    # Create masks
-    masks = create_masks(n1, n2, L)
-
-    # Perform Wirtinger Flow
-    Xhats, Times = wirtinger_flow(X, L, npower_iter, ttimes, tau0)
-
-    # Show some results
-    iter = [0] + ttimes
-    Relerrs = []
-    for mm in range(len(Xhats)):
-        mean_time = torch.mean(Times[:, mm])
-        Relerr = torch.norm(Xhats[mm].flatten() - X.flatten()) / torch.norm(X.flatten())
-        Relerrs.append(Relerr.item())
-        print(f"Mean running times after {iter[mm]} iterations: {mean_time:.1f}")
-        print(f"Relative error after {iter[mm]} iterations: {Relerr:.5f}\n")
-
-    # Display the resulting images
-    for tt in range(len(Xhats)):
-        img = Image.fromarray(np.uint8(Xhats[tt].numpy() * 255))
-        img.show()