diff --git a/db.0 b/db.0
index f4f28fdd36ec4e301790c2f5ce2773de7d6f5518..b3657d0f04bd6e52513fa73214e54176c509b647 100644
Binary files a/db.0 and b/db.0 differ
diff --git a/db.00 b/db.00
new file mode 100644
index 0000000000000000000000000000000000000000..6e67b120522c43b83db031a2f1c94ca5dee6b1fe
Binary files /dev/null and b/db.00 differ
diff --git a/generating.py b/generating.py
index 5234ce39f5d9d9ce44ddb6361661db2b9186a0fe..43190ac570c7e4d73397282f1e4aff381836b21b 100644
--- a/generating.py
+++ b/generating.py
@@ -1,31 +1,35 @@
 import torch
-import torch.utils.data as Data
-from scipy.linalg import eigvalsh
-import useful_utils
-import numpy as np
+
+# import torch.utils.data as Data
+# from scipy.linalg import eigvalsh
+# import numpy as np
 import math
-import params
 
 
-def gen_x_Amatrix_y(n,m, N, cplx_flag, DEVICE):
+def gen_x_Amatrix_y(n, m, N, cplx_flag, cuda_flag):
     """create x, y, and Amatrix"""
-    x_1 = torch.randn((N,n), dtype=torch.double,device=DEVICE)
-    x_2 = torch.randn((N,n), dtype=torch.double,device=DEVICE)*1j*cplx_flag
-    x= x_1 + x_2
-    x_1_val = torch.randn((N,n), dtype=torch.double,device=DEVICE)
-    x_2_val = torch.randn((N,n), dtype=torch.double,device=DEVICE)*1j*cplx_flag
+    if torch.cuda.is_available() and cuda_flag:
+        DEVICE = "cuda"
+    else:
+        DEVICE = "cpu"
+    x_1 = torch.randn((N, n), dtype=torch.double, device=DEVICE)
+    x_2 = torch.randn((N, n), dtype=torch.double, device=DEVICE) * 1j * cplx_flag
+    x = x_1 + x_2
+    x_1_val = torch.randn((N, n), dtype=torch.double, device=DEVICE)
+    x_2_val = torch.randn((N, n), dtype=torch.double, device=DEVICE) * 1j * cplx_flag
     x_val = x_1_val + x_2_val
-    Amatrix_1 = torch.randn((m,n), dtype=torch.double,device=DEVICE)
-    Amatrix_2 = torch.randn((m,n), dtype=torch.double,device=DEVICE)*1j*cplx_flag
-    scale = math.sqrt(2)**cplx_flag
-    Amatrix = (Amatrix_1 + Amatrix_2)/scale
-    y_trans = torch.abs(torch.linalg.matmul(Amatrix,x.T),device=DEVICE)  # y_i=|a_i x|
-    y_trans_val = torch.abs(torch.linalg.matmul(Amatrix,x_val.T),device=DEVICE)  # y_i=|a_i x|
-    return x, x_val, Amatrix, y_trans.T, y_trans_val.T
+    Amatrix_1 = torch.randn((n, m), dtype=torch.double, device=DEVICE)
+    Amatrix_2 = torch.randn((n, m), dtype=torch.double, device=DEVICE) * 1j * cplx_flag
+    scale = math.sqrt(2) ** cplx_flag
+    Amatrix = (Amatrix_1 + Amatrix_2) / scale
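+    # When cplx_flag == 1 the real and imaginary parts are each standard normal, so dividing by sqrt(2) keeps unit-variance entries.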
+    # y = |A(x)|: element-wise absolute value of the linear measurements
+    y = torch.abs(torch.linalg.matmul(x, Amatrix))
+    # same measurement model applied to the validation signals
+    y_val = torch.abs(torch.linalg.matmul(x_val, Amatrix))
+    return x, x_val, Amatrix, y, y_val
     # torch.save(x, 'x.pt')
     # torch.save(x_val, 'x_val.pt')
     # torch.save(Amatrix, 'Amatrix.pt')
-    # torch.save(y_trans.T, 'y.pt')
-    # torch.save(y_trans_val.T, 'y_val.pt')
-
-# gen_x_Amatrix_y(params.n,params.m,params.N,params.cplx_flag)
\ No newline at end of file
+    # torch.save(y, 'y.pt')
+    # torch.save(y_val, 'y_val.pt')
diff --git a/main.py b/main.py
index 70ca37c8537f728b7a2eaab7784619d7e6b1acc7..122f90ffe0d072f731a7b35a59ff4e41c11476dc 100644
--- a/main.py
+++ b/main.py
@@ -10,6 +10,7 @@ we here use a small subset of it.
 import generating
 
 import numpy as np
+import math
 import os
 import train_URWF
 import optuna
@@ -45,9 +46,44 @@ def objective(trial):
         DEVICE = "cuda"
     else:
         DEVICE = "cpu" 
+    
+    print(DEVICE)
     SCENARIO = trial.suggest_categorical("SCENARIO", ["scalar", "vector", "matrix", "vector-matrix", "tensor", "SPD-matrix"])
     LR = trial.suggest_float('LR', 1e-4, 1e-2)
-    x,x_val, Amatrix, y, y_val  = generating.gen_x_Amatrix_y(n,m,N,cplx_flag,DEVICE)
+    x, x_val, Amatrix, y, y_val = generating.gen_x_Amatrix_y(n, m, N, cplx_flag, DEVICE == "cuda")  # the new signature expects a boolean cuda flag, not the device string
+
+    def wf_solve(x,Amatrix,y, DEVICE):
+        def A(Amatrix,x):
+            result = torch.linalg.matmul(x, Amatrix)
+            return result
+        def Ah(Amatrix,x):
+            result = torch.linalg.matmul(x, Amatrix.H)
+            return result
+        def distance(z,x):
+            z = torch.flatten(z)
+            x = torch.flatten(x)
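+            # relative error between z and x after removing the global phase ambiguity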
+            error = torch.linalg.norm(x - torch.exp(-1j * torch.angle(torch.dot(torch.conj_physical(x), z))) * z) / torch.linalg.norm(x)
+            return error
+        T = 800
+        tau0 = 330
+        npower_iter = 30
+        Relerrs = np.zeros(T+1)
+        z0 = torch.randn((N, n), dtype=torch.cdouble,device=DEVICE)
+        z0 = z0/torch.linalg.norm(z0)
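+        # Wirtinger-Flow-style spectral initialization: power iterations on the y-weighted measurement operator give the starting point z0.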
+        for tt in range(npower_iter):
+            z0 = Ah(Amatrix, torch.multiply(y, A(Amatrix, z0)))
+            z0 = z0 / torch.linalg.norm(z0)
+        normest = math.sqrt(torch.sum(y)) / m
+        z0 = normest * z0
+        z = z0
+        print(distance(z, x))
+
+        return z
+    wf_solve(x,Amatrix,y,DEVICE)
+    print("datenkar")
+    
  
     # err_URWF_knot, err_URWF_inf = train_URWF.train_URWF(x, x_val, Amatrix, y, y_val, SCENARIO,LR,DEVICE)
     # return np.average(err_URWF_inf)
diff --git a/natural_images/irwf_reconstructed/irwf_reconstructed_-1_sat_phone.png b/natural_images/irwf_reconstructed/irwf_reconstructed_-1_sat_phone.png
new file mode 100644
index 0000000000000000000000000000000000000000..bf7fc7137d746ffa730c5bef9213756065b041fb
Binary files /dev/null and b/natural_images/irwf_reconstructed/irwf_reconstructed_-1_sat_phone.png differ
diff --git a/natural_images/natural_sat_phone_irwf.ipynb b/natural_images/natural_sat_phone_irwf.ipynb
index 1009b7b7dab62800a2bbc7463207ede7f893b997..4edb34483cce22bd55b42d67ee9f221a26926f50 100644
--- a/natural_images/natural_sat_phone_irwf.ipynb
+++ b/natural_images/natural_sat_phone_irwf.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -30,7 +30,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -100,7 +100,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -159,7 +159,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -188,7 +188,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -215,19 +215,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "tensor(19.4844, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(19.3103, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(19.4632, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(1.0007, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(1.0007, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(1.0008, device='cuda:0', dtype=torch.float64)\n"
+      "tensor(19.4858, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(19.3127, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(19.4663, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n"
      ]
     }
    ],
@@ -263,7 +263,12 @@
     "X_recons_b = distance_update(x_b,z_b)\n",
     "print(image_err(X_recons_r,x_r))\n",
     "print(image_err(X_recons_g,x_g))\n",
-    "print(image_err(X_recons_b,x_b))\n"
+    "print(image_err(X_recons_b,x_b))\n",
+    "X_recons[:,:,0] = distance_update(x_r,z_r)\n",
+    "X_recons[:,:,1] = distance_update(x_g,z_g)\n",
+    "X_recons[:,:,2] = distance_update(x_b,z_b)\n",
+    "save_image(X_recons,-1,file_name)\n",
+    "\n"
    ]
   },
   {
@@ -589,7 +594,7 @@
     "    ################################\n",
     "    X_recons[:,:,0] = distance_update(x_r,z_r)\n",
     "    err[t,0] = image_err(X_recons[:,:,0],X[:,:,0])\n",
-    "    print(\"iteration:\",t,\"error in the red channel:\",image_err(X_recons[:,:,0],X[:,:,0]))\n",
+    "    # print(\"iteration:\",t,\"error in the red channel:\",image_err(X_recons[:,:,0],X[:,:,0]))\n",
     "    yz_r = A(z_r)\n",
     "    yz_abs_r = torch.abs(yz_r)\n",
     "    first_divide_r = torch.divide(yz_r,yz_abs_r)\n",
@@ -602,7 +607,7 @@
     "    ################################\n",
     "    X_recons[:,:,1] = distance_update(x_g,z_g)\n",
     "    err[t,1] = image_err(X_recons[:,:,1],X[:,:,1])\n",
-    "    print(\"iteration:\",t,\"error in the green channel:\",image_err(X_recons[:,:,1],X[:,:,1]))\n",
+    "    # print(\"iteration:\",t,\"error in the green channel:\",image_err(X_recons[:,:,1],X[:,:,1]))\n",
     "    yz_g = A(z_g)\n",
     "    yz_abs_g = torch.abs(yz_g)\n",
     "    first_divide_g = torch.divide(yz_g,yz_abs_g)\n",
@@ -615,7 +620,7 @@
     "    ################################\n",
     "    X_recons[:,:,2] = distance_update(x_b,z_b)\n",
     "    err[t,2] = image_err(X_recons[:,:,2],X[:,:,2])\n",
-    "    print(\"iteration:\",t,\"error in the blue channel:\",image_err(X_recons[:,:,2],X[:,:,2]))\n",
+    "    # print(\"iteration:\",t,\"error in the blue channel:\",image_err(X_recons[:,:,2],X[:,:,2]))\n",
     "    yz_b = A(z_b)\n",
     "    yz_abs_b = torch.abs(yz_b)\n",
     "    first_divide_b = torch.divide(yz_b,yz_abs_b)\n",
diff --git a/natural_images/natural_sat_phone_rwf.ipynb b/natural_images/natural_sat_phone_rwf.ipynb
index 2ed4923407e4245d45674dc25f0032c9ec6659a5..9f37b9aa1c5da322715923f473ab5513c2e7b9b3 100644
--- a/natural_images/natural_sat_phone_rwf.ipynb
+++ b/natural_images/natural_sat_phone_rwf.ipynb
@@ -222,12 +222,12 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "tensor(19.4845, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(19.3110, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(19.4643, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(1.0007, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(19.4882, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(19.3146, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(19.4678, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n",
       "tensor(1.0010, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(1.0008, device='cuda:0', dtype=torch.float64)\n"
+      "tensor(1.0010, device='cuda:0', dtype=torch.float64)\n"
      ]
     }
    ],
@@ -263,7 +263,12 @@
     "X_recons_b = distance_update(x_b,z_b)\n",
     "print(image_err(X_recons_r,x_r))\n",
     "print(image_err(X_recons_g,x_g))\n",
-    "print(image_err(X_recons_b,x_b))\n"
+    "print(image_err(X_recons_b,x_b))\n",
+    "X_recons[:,:,0] = distance_update(x_r,z_r)\n",
+    "X_recons[:,:,1] = distance_update(x_g,z_g)\n",
+    "X_recons[:,:,2] = distance_update(x_b,z_b)\n",
+    "save_image(X_recons,-1,file_name)\n",
+    "\n"
    ]
   },
   {
diff --git a/natural_images/natural_sat_phone_twf.ipynb b/natural_images/natural_sat_phone_twf.ipynb
index d2af014f7f8e97cb3c8565cc31793663ef46fee0..2c85cbcf4483caf97c2f2344fbd79af3a483706c 100644
--- a/natural_images/natural_sat_phone_twf.ipynb
+++ b/natural_images/natural_sat_phone_twf.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 77,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -35,7 +35,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 78,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -51,7 +51,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 79,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -107,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 80,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -172,7 +172,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 81,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -190,7 +190,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 82,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -219,7 +219,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 83,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -246,7 +246,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 84,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -260,19 +260,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 85,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "tensor(428.3096, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(420.7117, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(427.4226, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(0.9130, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(0.6938, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(0.7306, device='cuda:0', dtype=torch.float64)\n"
+      "tensor(428.4117, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(420.8286, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(427.5470, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(0.5845, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(0.6381, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(0.9698, device='cuda:0', dtype=torch.float64)\n"
      ]
     }
    ],
@@ -310,7 +310,10 @@
     "print(image_err(X_recons_g,x_g))\n",
     "print(image_err(X_recons_b,x_b))\n",
     "\n",
-    "\n"
+    "X_recons[:,:,0] = distance_update(x_r,z_r)\n",
+    "X_recons[:,:,1] = distance_update(x_g,z_g)\n",
+    "X_recons[:,:,2] = distance_update(x_b,z_b)\n",
+    "save_image(X_recons,-1,file_name)"
    ]
   },
   {
@@ -1980,7 +1983,7 @@
     "    z_r_f = torch.flatten(z_r)\n",
     "    error = torch.linalg.norm(x_r_f - torch.exp(-1j*torch.angle(torch.matmul(x_r_f,z_r_f))) * z_r_f)/torch.linalg.norm(x_r_f)\n",
     "    return error\n",
-    "print(distance(X_recons,X))\n"
+    "# print(distance(X_recons,X))\n"
    ]
   },
   {
diff --git a/natural_images/natural_sat_phone_wf.ipynb b/natural_images/natural_sat_phone_wf.ipynb
index d0e239b19de51b686d8d2c5fe968057555796b8e..653dda7980c6f4273c2f938541042d5eee5a3def 100644
--- a/natural_images/natural_sat_phone_wf.ipynb
+++ b/natural_images/natural_sat_phone_wf.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -35,7 +35,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -51,7 +51,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -107,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -124,8 +124,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "torch.Size([563, 1000, 3])\n",
-      "tensor(737.2259, device='cuda:0', dtype=torch.float64)\n"
+      "torch.Size([563, 1000, 3])\n"
      ]
     }
    ],
@@ -163,7 +162,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -181,7 +180,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -210,7 +209,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -232,12 +231,14 @@
     "z0_b = z0_b/torch.linalg.norm(z0_b,'fro') \n",
     "############################################\n",
     "X_recons = torch.zeros(X.shape,dtype=torch.cdouble,device=DEVICE)\n",
+    "\n",
+    "\n",
     "\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
@@ -255,16 +256,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "tensor(428.4212, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(420.8395, device='cuda:0', dtype=torch.float64)\n",
-      "tensor(427.5393, device='cuda:0', dtype=torch.float64)\n"
+      "tensor(428.5580, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(420.9625, device='cuda:0', dtype=torch.float64)\n",
+      "tensor(427.6684, device='cuda:0', dtype=torch.float64)\n"
      ]
     }
    ],
@@ -289,6 +290,10 @@
     "print(normest_g)\n",
     "print(normest_b)\n",
     "\n",
+    "X_recons[:,:,0] = distance_update(x_r,z_r)\n",
+    "X_recons[:,:,1] = distance_update(x_g,z_g)\n",
+    "X_recons[:,:,2] = distance_update(x_b,z_b)\n",
+    "save_image(X_recons,-1,file_name)\n",
     "\n"
    ]
   },
diff --git a/natural_images/rwf_reconstructed/rwf_reconstructed_-1_sat_phone.png b/natural_images/rwf_reconstructed/rwf_reconstructed_-1_sat_phone.png
new file mode 100644
index 0000000000000000000000000000000000000000..0d8eefdb51fb18469235cc466c99b34dbb60e2eb
Binary files /dev/null and b/natural_images/rwf_reconstructed/rwf_reconstructed_-1_sat_phone.png differ
diff --git a/natural_images/twf_reconstructed/twf_reconstructed_-1_sat_phone.png b/natural_images/twf_reconstructed/twf_reconstructed_-1_sat_phone.png
new file mode 100644
index 0000000000000000000000000000000000000000..149581a835535c20872d014eae5d5695824fc3d8
Binary files /dev/null and b/natural_images/twf_reconstructed/twf_reconstructed_-1_sat_phone.png differ
diff --git a/natural_images/wf_reconstructed/wf_reconstructed_-1_sat_phone.png b/natural_images/wf_reconstructed/wf_reconstructed_-1_sat_phone.png
new file mode 100644
index 0000000000000000000000000000000000000000..9c678f8d639eb9ada28f97511d801dd0469541d1
Binary files /dev/null and b/natural_images/wf_reconstructed/wf_reconstructed_-1_sat_phone.png differ
diff --git a/params.py b/params.py
deleted file mode 100644
index 49f7049b140c3b117bb34bc61bef1ecb336bdefa..0000000000000000000000000000000000000000
--- a/params.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import argparse
-import torch
-import numpy as np
-# n = 64                      # signal dimension
-# alpha = 10
-# m = alpha * n               # number of measurements
-# # m = 2
-# cplx_flag = 1               # real: cplx_flag = 0;  complex: cplx_flag = 1;
-# T = 80                      # number of iterations
-# npower_iter = 30            # number of power iterations
-# N_train= 100                # Number of training samples
-# EPOCHS = 1
-# scenario = 4 
-# LR = 1e-3
-
-# N = N_train                 # Number of training samples
-# mu = 0.8+0.4*cplx_flag      # suggested step for the Wirtinger Flow
-# cuda_opt = 1
-# if torch.cuda.is_available() & cuda_opt:
-#     DEVICE = "cuda"
-# else:
-#     DEVICE = "cpu"              
-
-# batch = n
-
-
-# scalar = True
-# vector = True
-# matrix = True
-# tensor = True
-# if scenario == 0:
-    # scalar = True 
-    # vector = False
-    # matrix = False
-# if scenario == 1:
-    # scalar = False 
-    # vector = True
-    # matrix = False
-# if scenario == 2:
-    # scalar = False
-    # vector = True
-    # matrix = True
-# if scenario == 3:
-    # scalar = False
-    # vector = False
-    # matrix = True
-
-
-
-
-# print(Tensor)
-# parser = argparse.ArgumentParser()
-# parser.add_argument("-s", help="Scenario", type=int)
-# parser.add_argument("-lr", help="Learning Rate", type=float)
-# parser.add_argument("-e", help="epochs", type=int)
-# parser.add_argument("-n", help="SNR", type=int)
-# parser.add_argument("-c", help="Scenario", type=int)
-# parser.add_argument("-tstart", help="T start (Number of unfoldings)", type=int)
-# parser.add_argument("-tend", help="T end (Number of unfoldings)", type=int)
-# parser.add_argument("-tstep", help="T step (Number of unfoldings)", type=int)
-# parser.add_argument("-ntrain", help="N Train", type=int)
-# parser.add_argument("-sigsnr", help="Signal SNR", type=int)
-# parser.add_argument("-epochs", help="Number of epochs", type=int)
-# args = parser.parse_args()
-# scenario = args.s
-# LR = args.lr   
-# EPOCHS = args.e  
\ No newline at end of file
diff --git a/reconstructed b/reconstructed
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/reconstructed.png b/reconstructed.png
deleted file mode 100644
index 2ad830f8682a7d1c5c2fee06c801e34613fff9a2..0000000000000000000000000000000000000000
Binary files a/reconstructed.png and /dev/null differ
diff --git a/without_hyper.py b/without_hyper.py
index a63b7ed15c226a44f1457f44df66dde89fe0b6bd..b5108b80f8ec24b89fc491e40697b9bcadc00d0e 100644
--- a/without_hyper.py
+++ b/without_hyper.py
@@ -1,52 +1,97 @@
-import numpy as np
-import params
 import torch
-# import torch
+import numpy as np
+import math
 import generating
-import matplotlib.pyplot as plt
-import train_URWF
-import train_UIRWF
-import visualize
-import argparse
-import time
-# import pathlib
-import datetime
-from datetime import timedelta
-
-start = time.time()
-print("n_samples = ", params.N_train)
-n_val = int(0.1 * params.N_train)
-print("n_val = ", n_val)
-# print("model = UIRWF")
-print("scenario = ", params.scenario)
-print("ML_LR = ", params.LR)
-print("EPOCHS = ", params.EPOCHS)
-print("device = ",params.DEVICE)
-generating.gen_x_Amatrix_y(params.n,params.m,params.N,params.cplx_flag)
-err_URWF_knot, err_URWF_inf = train_URWF.train_URWF()
-print(np.average(err_URWF_inf))
-# print(err_URWF_knot)
-# print(err_URWF_inf)
-# err_UIRWF_knot, err_UIRWF_inf = train_UIRWF.train_UIRWF()
-visualize.plot()
-
-plt.rcParams['text.usetex'] = True
-plt.title("err vs epoch index")
-plt.xlabel("$epochs$")
-plt.ylabel(r'$\min \| z_{i}e^{-i\phi} - x_i \|_F$', fontsize=14, color='k')
-plt.figure(2)
-plt.semilogy(np.arange(0, params.N_train), err_URWF_knot, color ="red",label="URWF on untrained network")
-plt.semilogy(np.arange(0, params.N_train), err_URWF_inf, color ="blue",label="URWF on trained network")
-# plt.semilogy(np.arange(0, params.N_train), err_UIRWF_knot, color ="orange",label="UIRWF on untrained network")
-# plt.semilogy(np.arange(0, params.N_train), err_UIRWF_inf, color ="green",label="UIRWF on trained network")
-plt.title(r'URWF vs UIRWF on (un)trained networks',fontsize=10)
-plt.legend()
-plt.savefig('Scenario = '+str(params.scenario)+', Layers = '+str(params.T)+', Epochs = '+str(params.EPOCHS)+'_ LR = '+str(params.LR)+'_(un)trained.pdf')
-# plt.show()
-
-
-
-end = time.time()
-elapsed_time = end - start
-print("elapsed time is:")
-print(str(timedelta(seconds=elapsed_time)),"(HH:MM:SS)")
+
+
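+# Problem setup: n-dimensional signals, m = 10n magnitude-only measurements, N signals; cplx_flag/cuda_flag toggle complex data and GPU use.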
+n = 64
+m = 10 * n
+N = 2
+cplx_flag = 1
+cuda_flag = 1
+x, x_val, Amatrix, y, y_val = generating.gen_x_Amatrix_y(n, m, N, cplx_flag, cuda_flag)
+
+
+def wf_solve(x, Amatrix, y, cuda_flag):
+    if torch.cuda.is_available() and cuda_flag:
+        DEVICE = "cuda"
+    else:
+        DEVICE = "cpu"
+
+    def A(Amatrix, x):
+        result = torch.linalg.matmul(x, Amatrix)
+        return result
+
+    def Ah(Amatrix, x):
+        result = torch.linalg.matmul(x, Amatrix.H)
+        return result
+
+    def distance(z, x):
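+        # Relative error up to a global phase: z is rotated by exp(-1j*angle(<x, z>)) before being compared with x.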
+        z = torch.flatten(z)
+        x = torch.flatten(x)
+        error = torch.linalg.norm(
+            x - torch.exp(-1j * torch.angle(torch.dot(torch.conj_physical(x), z))) * z
+        ) / torch.linalg.norm(x)
+        # error = torch.linalg.norm(
+            # x - torch.exp(-1j * torch.exp(-1j*torch.angle(torch.trace(x.H*z)))) * z
+        # ) / torch.linalg.norm(x)
+        
+        return error
+
+    T = 800
+    tau0 = 330
+    npower_iter = 300
+    # Relerrs = np.zeros(T + 1)
+    z0 = torch.randn(x.shape, dtype=torch.cdouble, device=DEVICE)
+    z0 = z0 / torch.linalg.norm(z0)
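+    # Initialization: power iterations on the y-weighted measurement operator, then rescale z0 by a norm estimate derived from y.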
+    for tt in range(npower_iter):
+        z0 = Ah(Amatrix, torch.multiply(y, A(Amatrix, z0)))
+        z0 = z0 / torch.linalg.norm(z0)
+    normest = math.sqrt(torch.sum(y)) / m
+    z0 = normest * z0
+    z = z0
+
+    print(distance(z,x))
+    normest = math.sqrt(torch.sum(y**2)/m)
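+    # Gradient loop: grad = (1/m) * Ah((|Az|^2 - y^2) * Az), with the step size ramped as min(1 - exp(-t/tau0), 0.2) / normest^2, matching the Wirtinger Flow step-size heuristic.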
+    for tt in range(T+1):
+        yz = A(Amatrix, z)
+        grad = 1 / m * Ah(Amatrix, torch.multiply(torch.abs(yz) ** 2 - y ** 2, yz))
+        z = z - min(1 - math.exp(-tt / tau0), 0.2) / normest ** 2 * grad
+        print("iteration:",tt,distance(z, x))
+        print("iteration:",tt,distance(z, x*1j))
+    print(distance(z, x))
+    # zero = torch.zeros(x.shape,dtype=torch.cdouble,device=DEVICE)
+    # print(distance(zero, x))
+    return z
+
+z = wf_solve(x, Amatrix, y, cuda_flag)
+
+
+def distance(z, x):
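+    # Same phase-invariant error metric as the helper inside wf_solve, repeated here for the post-hoc checks below.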
+    z = torch.flatten(z)
+    x = torch.flatten(x)
+    error = torch.linalg.norm(
+        x - torch.exp(-1j * torch.angle(torch.dot(torch.conj_physical(x), z))) * z
+    ) / torch.linalg.norm(x)
+    # error = torch.linalg.norm(
+        # x - torch.exp(-1j * torch.exp(-1j*torch.angle(torch.trace(x.H*z)))) * z
+    # ) / torch.linalg.norm(x)
+    
+    return error
+
+print(distance(x,x))
+print(distance(z,x))
+# if torch.cuda.is_available() & cuda_flag:
+#     DEVICE = "cuda"
+# else:
+#     DEVICE = "cpu"
+# z0 = torch.randn( n, dtype=torch.cdouble, device=DEVICE)
+# z0 = z0 / torch.linalg.norm(z0)
+# z = z0
+# zero = torch.zeros(z.shape,dtype=torch.cdouble,device=DEVICE)
+# print(distance(x,x))