diff --git a/db.0 b/db.0 index f4f28fdd36ec4e301790c2f5ce2773de7d6f5518..b3657d0f04bd6e52513fa73214e54176c509b647 100644 Binary files a/db.0 and b/db.0 differ diff --git a/db.00 b/db.00 new file mode 100644 index 0000000000000000000000000000000000000000..6e67b120522c43b83db031a2f1c94ca5dee6b1fe Binary files /dev/null and b/db.00 differ diff --git a/generating.py b/generating.py index 5234ce39f5d9d9ce44ddb6361661db2b9186a0fe..43190ac570c7e4d73397282f1e4aff381836b21b 100644 --- a/generating.py +++ b/generating.py @@ -1,31 +1,35 @@ import torch -import torch.utils.data as Data -from scipy.linalg import eigvalsh -import useful_utils -import numpy as np + +# import torch.utils.data as Data +# from scipy.linalg import eigvalsh +# import numpy as np +import torch import math -import params -def gen_x_Amatrix_y(n,m, N, cplx_flag, DEVICE): +def gen_x_Amatrix_y(n, m, N, cplx_flag, cuda_flag): """create x, y, and Amatrix""" - x_1 = torch.randn((N,n), dtype=torch.double,device=DEVICE) - x_2 = torch.randn((N,n), dtype=torch.double,device=DEVICE)*1j*cplx_flag - x= x_1 + x_2 - x_1_val = torch.randn((N,n), dtype=torch.double,device=DEVICE) - x_2_val = torch.randn((N,n), dtype=torch.double,device=DEVICE)*1j*cplx_flag + if torch.cuda.is_available() & cuda_flag: + DEVICE = "cuda" + else: + DEVICE = "cpu" + x_1 = torch.randn((N, n), dtype=torch.double, device=DEVICE) + x_2 = torch.randn((N, n), dtype=torch.double, device=DEVICE) * 1j * cplx_flag + x = x_1 + x_2 + x_1_val = torch.randn((N, n), dtype=torch.double, device=DEVICE) + x_2_val = torch.randn((N, n), dtype=torch.double, device=DEVICE) * 1j * cplx_flag x_val = x_1_val + x_2_val - Amatrix_1 = torch.randn((m,n), dtype=torch.double,device=DEVICE) - Amatrix_2 = torch.randn((m,n), dtype=torch.double,device=DEVICE)*1j*cplx_flag - scale = math.sqrt(2)**cplx_flag - Amatrix = (Amatrix_1 + Amatrix_2)/scale - y_trans = torch.abs(torch.linalg.matmul(Amatrix,x.T),device=DEVICE) # y_i=|a_i x| - y_trans_val = torch.abs(torch.linalg.matmul(Amatrix,x_val.T),device=DEVICE) # y_i=|a_i x| - return x, x_val, Amatrix, y_trans.T, y_trans_val.T + Amatrix_1 = torch.randn((n, m), dtype=torch.double, device=DEVICE) + Amatrix_2 = torch.randn((n, m), dtype=torch.double, device=DEVICE) * 1j * cplx_flag + scale = math.sqrt(2) ** cplx_flag + Amatrix = (Amatrix_1 + Amatrix_2) / scale + # y =|A(x)| Element-wise absolute value of a linear operator + y = torch.abs(torch.linalg.matmul(x, Amatrix)) + # y =|A(x)| Element-wise absolute value of a linear operator + y_val = torch.abs(torch.linalg.matmul(x_val, Amatrix)) + return x, x_val, Amatrix, y, y_val # torch.save(x, 'x.pt') # torch.save(x_val, 'x_val.pt') # torch.save(Amatrix, 'Amatrix.pt') - # torch.save(y_trans.T, 'y.pt') - # torch.save(y_trans_val.T, 'y_val.pt') - -# gen_x_Amatrix_y(params.n,params.m,params.N,params.cplx_flag) \ No newline at end of file + # torch.save(y, 'y.pt') + # torch.save(y_val, 'y_val.pt') diff --git a/main.py b/main.py index 70ca37c8537f728b7a2eaab7784619d7e6b1acc7..122f90ffe0d072f731a7b35a59ff4e41c11476dc 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ we here use a small subset of it. import generating import numpy as np +import math import os import train_URWF import optuna @@ -45,9 +46,44 @@ def objective(trial): DEVICE = "cuda" else: DEVICE = "cpu" + + print(DEVICE) SCENARIO = trial.suggest_categorical("SCENARIO", ["scalar", "vector", "matrix", "vector-matrix", "tensor", "SPD-matrix"]) LR = trial.suggest_float('LR', 1e-4, 1e-2) - x,x_val, Amatrix, y, y_val = generating.gen_x_Amatrix_y(n,m,N,cplx_flag,DEVICE) + x,x_val, Amatrix, y, y_val = generating.gen_x_Amatrix_y(n,m,N,cplx_flag, DEVICE) + + def wf_solve(x,Amatrix,y, DEVICE): + def A(Amatrix,x): + result = torch.linalg.matmul(x, Amatrix) + return result + def Ah(Amatrix,x): + result = torch.linalg.matmul(x, Amatrix.H) + return result + def distance(z,x): + z = torch.flatten(z) + x = torch.flatten(x) + error = torch.linalg.norm(x - torch.exp(-1j*torch.angle(torch.matmul(x,z))) * z)/torch.linalg.norm(x) + return error + T = 800 + tau0 = 330 + npower_iter = 30 + Relerrs = np.zeros(T+1) + z0 = torch.randn((N, n), dtype=torch.cdouble,device=DEVICE) + z0 = z0/torch.linalg.norm(z0) + for tt in range(npower_iter): + z0 = Ah(Amatrix,torch.multiply(y, (A(Amatrix,z0)))) + z0 = z0/torch.linalg.norm(z0) + normest = math.sqrt(torch.sum(y))/m + z0 = normest * z0 + z = z0 + print(distance(x,z)) + + + + return 0 + wf_solve(x,Amatrix,y,DEVICE) + print("datenkar") + # err_URWF_knot, err_URWF_inf = train_URWF.train_URWF(x, x_val, Amatrix, y, y_val, SCENARIO,LR,DEVICE) # return np.average(err_URWF_inf) diff --git a/natural_images/irwf_reconstructed/irwf_reconstructed_-1_sat_phone.png b/natural_images/irwf_reconstructed/irwf_reconstructed_-1_sat_phone.png new file mode 100644 index 0000000000000000000000000000000000000000..bf7fc7137d746ffa730c5bef9213756065b041fb Binary files /dev/null and b/natural_images/irwf_reconstructed/irwf_reconstructed_-1_sat_phone.png differ diff --git a/natural_images/natural_sat_phone_irwf.ipynb b/natural_images/natural_sat_phone_irwf.ipynb index 1009b7b7dab62800a2bbc7463207ede7f893b997..4edb34483cce22bd55b42d67ee9f221a26926f50 100644 --- a/natural_images/natural_sat_phone_irwf.ipynb +++ b/natural_images/natural_sat_phone_irwf.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -188,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -215,19 +215,19 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(19.4844, device='cuda:0', dtype=torch.float64)\n", - "tensor(19.3103, device='cuda:0', dtype=torch.float64)\n", - "tensor(19.4632, device='cuda:0', dtype=torch.float64)\n", - "tensor(1.0007, device='cuda:0', dtype=torch.float64)\n", - "tensor(1.0007, device='cuda:0', dtype=torch.float64)\n", - "tensor(1.0008, device='cuda:0', dtype=torch.float64)\n" + "tensor(19.4858, device='cuda:0', dtype=torch.float64)\n", + "tensor(19.3127, device='cuda:0', dtype=torch.float64)\n", + "tensor(19.4663, device='cuda:0', dtype=torch.float64)\n", + "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n", + "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n", + "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n" ] } ], @@ -263,7 +263,12 @@ "X_recons_b = distance_update(x_b,z_b)\n", "print(image_err(X_recons_r,x_r))\n", "print(image_err(X_recons_g,x_g))\n", - "print(image_err(X_recons_b,x_b))\n" + "print(image_err(X_recons_b,x_b))\n", + "X_recons[:,:,0] = distance_update(x_r,z_r)\n", + "X_recons[:,:,1] = distance_update(x_g,z_g)\n", + "X_recons[:,:,2] = distance_update(x_b,z_b)\n", + "save_image(X_recons,-1,file_name)\n", + "\n" ] }, { @@ -589,7 +594,7 @@ " ################################\n", " X_recons[:,:,0] = distance_update(x_r,z_r)\n", " err[t,0] = image_err(X_recons[:,:,0],X[:,:,0])\n", - " print(\"iteration:\",t,\"error in the red channel:\",image_err(X_recons[:,:,0],X[:,:,0]))\n", + " # print(\"iteration:\",t,\"error in the red channel:\",image_err(X_recons[:,:,0],X[:,:,0]))\n", " yz_r = A(z_r)\n", " yz_abs_r = torch.abs(yz_r)\n", " first_divide_r = torch.divide(yz_r,yz_abs_r)\n", @@ -602,7 +607,7 @@ " ################################\n", " X_recons[:,:,1] = distance_update(x_g,z_g)\n", " err[t,1] = image_err(X_recons[:,:,1],X[:,:,1])\n", - " print(\"iteration:\",t,\"error in the green channel:\",image_err(X_recons[:,:,1],X[:,:,1]))\n", + " # print(\"iteration:\",t,\"error in the green channel:\",image_err(X_recons[:,:,1],X[:,:,1]))\n", " yz_g = A(z_g)\n", " yz_abs_g = torch.abs(yz_g)\n", " first_divide_g = torch.divide(yz_g,yz_abs_g)\n", @@ -615,7 +620,7 @@ " ################################\n", " X_recons[:,:,2] = distance_update(x_b,z_b)\n", " err[t,2] = image_err(X_recons[:,:,2],X[:,:,2])\n", - " print(\"iteration:\",t,\"error in the blue channel:\",image_err(X_recons[:,:,2],X[:,:,2]))\n", + " # print(\"iteration:\",t,\"error in the blue channel:\",image_err(X_recons[:,:,2],X[:,:,2]))\n", " yz_b = A(z_b)\n", " yz_abs_b = torch.abs(yz_b)\n", " first_divide_b = torch.divide(yz_b,yz_abs_b)\n", diff --git a/natural_images/natural_sat_phone_rwf.ipynb b/natural_images/natural_sat_phone_rwf.ipynb index 2ed4923407e4245d45674dc25f0032c9ec6659a5..9f37b9aa1c5da322715923f473ab5513c2e7b9b3 100644 --- a/natural_images/natural_sat_phone_rwf.ipynb +++ b/natural_images/natural_sat_phone_rwf.ipynb @@ -222,12 +222,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor(19.4845, device='cuda:0', dtype=torch.float64)\n", - "tensor(19.3110, device='cuda:0', dtype=torch.float64)\n", - "tensor(19.4643, device='cuda:0', dtype=torch.float64)\n", - "tensor(1.0007, device='cuda:0', dtype=torch.float64)\n", + "tensor(19.4882, device='cuda:0', dtype=torch.float64)\n", + "tensor(19.3146, device='cuda:0', dtype=torch.float64)\n", + "tensor(19.4678, device='cuda:0', dtype=torch.float64)\n", + "tensor(1.0009, device='cuda:0', dtype=torch.float64)\n", "tensor(1.0010, device='cuda:0', dtype=torch.float64)\n", - "tensor(1.0008, device='cuda:0', dtype=torch.float64)\n" + "tensor(1.0010, device='cuda:0', dtype=torch.float64)\n" ] } ], @@ -263,7 +263,12 @@ "X_recons_b = distance_update(x_b,z_b)\n", "print(image_err(X_recons_r,x_r))\n", "print(image_err(X_recons_g,x_g))\n", - "print(image_err(X_recons_b,x_b))\n" + "print(image_err(X_recons_b,x_b))\n", + "X_recons[:,:,0] = distance_update(x_r,z_r)\n", + "X_recons[:,:,1] = distance_update(x_g,z_g)\n", + "X_recons[:,:,2] = distance_update(x_b,z_b)\n", + "save_image(X_recons,-1,file_name)\n", + "\n" ] }, { diff --git a/natural_images/natural_sat_phone_twf.ipynb b/natural_images/natural_sat_phone_twf.ipynb index d2af014f7f8e97cb3c8565cc31793663ef46fee0..2c85cbcf4483caf97c2f2344fbd79af3a483706c 100644 --- a/natural_images/natural_sat_phone_twf.ipynb +++ b/natural_images/natural_sat_phone_twf.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 77, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -172,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -190,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -219,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -260,19 +260,19 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(428.3096, device='cuda:0', dtype=torch.float64)\n", - "tensor(420.7117, device='cuda:0', dtype=torch.float64)\n", - "tensor(427.4226, device='cuda:0', dtype=torch.float64)\n", - "tensor(0.9130, device='cuda:0', dtype=torch.float64)\n", - "tensor(0.6938, device='cuda:0', dtype=torch.float64)\n", - "tensor(0.7306, device='cuda:0', dtype=torch.float64)\n" + "tensor(428.4117, device='cuda:0', dtype=torch.float64)\n", + "tensor(420.8286, device='cuda:0', dtype=torch.float64)\n", + "tensor(427.5470, device='cuda:0', dtype=torch.float64)\n", + "tensor(0.5845, device='cuda:0', dtype=torch.float64)\n", + "tensor(0.6381, device='cuda:0', dtype=torch.float64)\n", + "tensor(0.9698, device='cuda:0', dtype=torch.float64)\n" ] } ], @@ -310,7 +310,10 @@ "print(image_err(X_recons_g,x_g))\n", "print(image_err(X_recons_b,x_b))\n", "\n", - "\n" + "X_recons[:,:,0] = distance_update(x_r,z_r)\n", + "X_recons[:,:,1] = distance_update(x_g,z_g)\n", + "X_recons[:,:,2] = distance_update(x_b,z_b)\n", + "save_image(X_recons,-1,file_name)" ] }, { @@ -1980,7 +1983,7 @@ " z_r_f = torch.flatten(z_r)\n", " error = torch.linalg.norm(x_r_f - torch.exp(-1j*torch.angle(torch.matmul(x_r_f,z_r_f))) * z_r_f)/torch.linalg.norm(x_r_f)\n", " return error\n", - "print(distance(X_recons,X))\n" + "# print(distance(X_recons,X))\n" ] }, { diff --git a/natural_images/natural_sat_phone_wf.ipynb b/natural_images/natural_sat_phone_wf.ipynb index d0e239b19de51b686d8d2c5fe968057555796b8e..653dda7980c6f4273c2f938541042d5eee5a3def 100644 --- a/natural_images/natural_sat_phone_wf.ipynb +++ b/natural_images/natural_sat_phone_wf.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -124,8 +124,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([563, 1000, 3])\n", - "tensor(737.2259, device='cuda:0', dtype=torch.float64)\n" + "torch.Size([563, 1000, 3])\n" ] } ], @@ -163,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -181,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -210,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -232,12 +231,14 @@ "z0_b = z0_b/torch.linalg.norm(z0_b,'fro') \n", "############################################\n", "X_recons = torch.zeros(X.shape,dtype=torch.cdouble,device=DEVICE)\n", + "\n", + "\n", "\n" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -255,16 +256,16 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(428.4212, device='cuda:0', dtype=torch.float64)\n", - "tensor(420.8395, device='cuda:0', dtype=torch.float64)\n", - "tensor(427.5393, device='cuda:0', dtype=torch.float64)\n" + "tensor(428.5580, device='cuda:0', dtype=torch.float64)\n", + "tensor(420.9625, device='cuda:0', dtype=torch.float64)\n", + "tensor(427.6684, device='cuda:0', dtype=torch.float64)\n" ] } ], @@ -289,6 +290,10 @@ "print(normest_g)\n", "print(normest_b)\n", "\n", + "X_recons[:,:,0] = distance_update(x_r,z_r)\n", + "X_recons[:,:,1] = distance_update(x_g,z_g)\n", + "X_recons[:,:,2] = distance_update(x_b,z_b)\n", + "save_image(X_recons,-1,file_name)\n", "\n" ] }, diff --git a/natural_images/rwf_reconstructed/rwf_reconstructed_-1_sat_phone.png b/natural_images/rwf_reconstructed/rwf_reconstructed_-1_sat_phone.png new file mode 100644 index 0000000000000000000000000000000000000000..0d8eefdb51fb18469235cc466c99b34dbb60e2eb Binary files /dev/null and b/natural_images/rwf_reconstructed/rwf_reconstructed_-1_sat_phone.png differ diff --git a/natural_images/twf_reconstructed/twf_reconstructed_-1_sat_phone.png b/natural_images/twf_reconstructed/twf_reconstructed_-1_sat_phone.png new file mode 100644 index 0000000000000000000000000000000000000000..149581a835535c20872d014eae5d5695824fc3d8 Binary files /dev/null and b/natural_images/twf_reconstructed/twf_reconstructed_-1_sat_phone.png differ diff --git a/natural_images/wf_reconstructed/wf_reconstructed_-1_sat_phone.png b/natural_images/wf_reconstructed/wf_reconstructed_-1_sat_phone.png new file mode 100644 index 0000000000000000000000000000000000000000..9c678f8d639eb9ada28f97511d801dd0469541d1 Binary files /dev/null and b/natural_images/wf_reconstructed/wf_reconstructed_-1_sat_phone.png differ diff --git a/params.py b/params.py deleted file mode 100644 index 49f7049b140c3b117bb34bc61bef1ecb336bdefa..0000000000000000000000000000000000000000 --- a/params.py +++ /dev/null @@ -1,67 +0,0 @@ -import argparse -import torch -import numpy as np -# n = 64 # signal dimension -# alpha = 10 -# m = alpha * n # number of measurements -# # m = 2 -# cplx_flag = 1 # real: cplx_flag = 0; complex: cplx_flag = 1; -# T = 80 # number of iterations -# npower_iter = 30 # number of power iterations -# N_train= 100 # Number of training samples -# EPOCHS = 1 -# scenario = 4 -# LR = 1e-3 - -# N = N_train # Number of training samples -# mu = 0.8+0.4*cplx_flag # suggested step for the Wirtinger Flow -# cuda_opt = 1 -# if torch.cuda.is_available() & cuda_opt: -# DEVICE = "cuda" -# else: -# DEVICE = "cpu" - -# batch = n - - -# scalar = True -# vector = True -# matrix = True -# tensor = True -# if scenario == 0: - # scalar = True - # vector = False - # matrix = False -# if scenario == 1: - # scalar = False - # vector = True - # matrix = False -# if scenario == 2: - # scalar = False - # vector = True - # matrix = True -# if scenario == 3: - # scalar = False - # vector = False - # matrix = True - - - - -# print(Tensor) -# parser = argparse.ArgumentParser() -# parser.add_argument("-s", help="Scenario", type=int) -# parser.add_argument("-lr", help="Learning Rate", type=float) -# parser.add_argument("-e", help="epochs", type=int) -# parser.add_argument("-n", help="SNR", type=int) -# parser.add_argument("-c", help="Scenario", type=int) -# parser.add_argument("-tstart", help="T start (Number of unfoldings)", type=int) -# parser.add_argument("-tend", help="T end (Number of unfoldings)", type=int) -# parser.add_argument("-tstep", help="T step (Number of unfoldings)", type=int) -# parser.add_argument("-ntrain", help="N Train", type=int) -# parser.add_argument("-sigsnr", help="Signal SNR", type=int) -# parser.add_argument("-epochs", help="Number of epochs", type=int) -# args = parser.parse_args() -# scenario = args.s -# LR = args.lr -# EPOCHS = args.e \ No newline at end of file diff --git a/reconstructed b/reconstructed deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/reconstructed.png b/reconstructed.png deleted file mode 100644 index 2ad830f8682a7d1c5c2fee06c801e34613fff9a2..0000000000000000000000000000000000000000 Binary files a/reconstructed.png and /dev/null differ diff --git a/without_hyper.py b/without_hyper.py index a63b7ed15c226a44f1457f44df66dde89fe0b6bd..b5108b80f8ec24b89fc491e40697b9bcadc00d0e 100644 --- a/without_hyper.py +++ b/without_hyper.py @@ -1,52 +1,92 @@ -import numpy as np -import params import torch -# import torch +import numpy as np +import math import generating -import matplotlib.pyplot as plt -import train_URWF -import train_UIRWF -import visualize -import argparse -import time -# import pathlib -import datetime -from datetime import timedelta - -start = time.time() -print("n_samples = ", params.N_train) -n_val = int(0.1 * params.N_train) -print("n_val = ", n_val) -# print("model = UIRWF") -print("scenario = ", params.scenario) -print("ML_LR = ", params.LR) -print("EPOCHS = ", params.EPOCHS) -print("device = ",params.DEVICE) -generating.gen_x_Amatrix_y(params.n,params.m,params.N,params.cplx_flag) -err_URWF_knot, err_URWF_inf = train_URWF.train_URWF() -print(np.average(err_URWF_inf)) -# print(err_URWF_knot) -# print(err_URWF_inf) -# err_UIRWF_knot, err_UIRWF_inf = train_UIRWF.train_UIRWF() -visualize.plot() - -plt.rcParams['text.usetex'] = True -plt.title("err vs epoch index") -plt.xlabel("$epochs$") -plt.ylabel(r'$\min \| z_{i}e^{-i\phi} - x_i \|_F$', fontsize=14, color='k') -plt.figure(2) -plt.semilogy(np.arange(0, params.N_train), err_URWF_knot, color ="red",label="URWF on untrained network") -plt.semilogy(np.arange(0, params.N_train), err_URWF_inf, color ="blue",label="URWF on trained network") -# plt.semilogy(np.arange(0, params.N_train), err_UIRWF_knot, color ="orange",label="UIRWF on untrained network") -# plt.semilogy(np.arange(0, params.N_train), err_UIRWF_inf, color ="green",label="UIRWF on trained network") -plt.title(r'URWF vs UIRWF on (un)trained networks',fontsize=10) -plt.legend() -plt.savefig('Scenario = '+str(params.scenario)+', Layers = '+str(params.T)+', Epochs = '+str(params.EPOCHS)+'_ LR = '+str(params.LR)+'_(un)trained.pdf') -# plt.show() - - - -end = time.time() -elapsed_time = end - start -print("elapsed time is:") -print(str(timedelta(seconds=elapsed_time)),"(HH:MM:SS)") + + +n = 64 +m = 10 * n +N = 2 +cplx_flag = 1 +cuda_flag = 1 +x, x_val, Amatrix, y, y_val = generating.gen_x_Amatrix_y(n, m, N, cplx_flag, cuda_flag) + + +def wf_solve(x, Amatrix, y, cuda_flag): + if torch.cuda.is_available() & cuda_flag: + DEVICE = "cuda" + else: + DEVICE = "cpu" + + def A(Amatrix, x): + result = torch.linalg.matmul(x, Amatrix) + return result + + def Ah(Amatrix, x): + result = torch.linalg.matmul(x, Amatrix.H) + return result + + def distance(z, x): + z = torch.flatten(z) + x = torch.flatten(x) + error = torch.linalg.norm( + x - torch.exp(-1j * torch.angle(torch.dot(torch.conj_physical(x), z))) * z + ) / torch.linalg.norm(x) + # error = torch.linalg.norm( + # x - torch.exp(-1j * torch.exp(-1j*torch.angle(torch.trace(x.H*z)))) * z + # ) / torch.linalg.norm(x) + + return error + + T = 800 + tau0 = 330 + npower_iter = 300 + # Relerrs = np.zeros(T + 1) + z0 = torch.randn(x.shape, dtype=torch.cdouble, device=DEVICE) + z0 = z0 / torch.linalg.norm(z0) + for tt in range(npower_iter): + z0 = Ah(Amatrix,torch.multiply(y, (A(Amatrix,z0)))) + z0 = z0/torch.linalg.norm(z0) + normest = math.sqrt(torch.sum(y))/m + z0 = normest * z0 + z = z0 + + print(distance(z,x)) + normest = math.sqrt(torch.sum(y**2)/m) + for tt in range(T+1): + yz = A(Amatrix , z) + grad = 1/m * Ah( Amatrix, torch.multiply(abs(yz)**2-y**2, yz)) + z = z - (min(1-math.exp(-tt/tau0), 0.2))/normest**2 * grad + print("iteration:",tt,distance(z, x)) + print("iteration:",tt,distance(z, x*1j)) + print(distance(z, x)) + # zero = torch.zeros(x.shape,dtype=torch.cdouble,device=DEVICE) + # print(distance(zero, x)) + return z + +z = wf_solve(x, Amatrix, y, cuda_flag) + + +def distance(z, x): + z = torch.flatten(z) + x = torch.flatten(x) + error = torch.linalg.norm( + x - torch.exp(-1j * torch.angle(torch.dot(torch.conj_physical(x), z))) * z + ) / torch.linalg.norm(x) + # error = torch.linalg.norm( + # x - torch.exp(-1j * torch.exp(-1j*torch.angle(torch.trace(x.H*z)))) * z + # ) / torch.linalg.norm(x) + + return error + +print(distance(x,x)) +print(distance(z,x)) +# if torch.cuda.is_available() & cuda_flag: +# DEVICE = "cuda" +# else: +# DEVICE = "cpu" +# z0 = torch.randn( n, dtype=torch.cdouble, device=DEVICE) +# z0 = z0 / torch.linalg.norm(z0) +# z = z0 +# zero = torch.zeros(z.shape,dtype=torch.cdouble,device=DEVICE) +# print(distance(x,x))