Skip to content
Snippets Groups Projects
Commit fd106d79 authored by Ali Darijani's avatar Ali Darijani
Browse files

commit

parent 5bfd10e6
No related branches found
No related tags found
No related merge requests found
*.pt
*.npy
This diff is collapsed.
This diff is collapsed.
natural_images/output_image.png

3.61 MiB

output.png

24.3 KiB

This diff is collapsed.
This diff is collapsed.
import tikzplotlib
import wf_variants
import torch
import numpy as np
......@@ -7,6 +8,22 @@ from pprint import pprint
import time
from datetime import timedelta
start = time.time()
def tikzplotlib_fix_ncols(obj):
"""
workaround for matplotlib 3.6 renamed legend's _ncol to _ncols, which breaks tikzplotlib
"""
if hasattr(obj, "_ncols"):
obj._ncol = obj._ncols
for child in obj.get_children():
tikzplotlib_fix_ncols(child)
# def tikzplotlib_fix_ncols(obj):
# """
# workaround for matplotlib 3.6 renamed legend's _ncol to _ncols, which breaks tikzplotlib
# """
# if hasattr(obj, "_ncols"):
# obj._ncol = obj._ncols
# for child in obj.get_children():
# tikzplotlib_fix_ncols(child)
# Python using Pytorch version of https://github.com/hubevan/reshaped-Wirtinger-flow/blob/master/exampleRWF.m
# Example of the RWF and IRWF algorithm under 1D Gaussian designs
# The code below is adapted from implementation of the TWF desinged by Y. Chen and E. Candes, Wirtinger Flow algorithm designed and implemented by E. Candes, X. Li, and M. Soltanolkotabi
......@@ -23,17 +40,23 @@ rwf_err = wf_variants.rwf(A,x,y)
irwf_err = wf_variants.irwf(A,x,y,batch= 1)
imrwf_err = wf_variants.irwf(A,x,y,batch= 64)
plt.semilogy(np.arange(0, np.size(wf_err)), wf_err, '-r',label="Wirtinger Flow")
plt.semilogy(np.arange(0, np.size(twf_err)), twf_err, '-m',label="Truncated Wirtinger Flow")
plt.semilogy(np.arange(0, np.size(wf_err)), wf_err, '-r',label="WF")
plt.semilogy(np.arange(0, np.size(twf_err)), twf_err, '-m',label="TWF")
# plt.semilogy(np.arange(0, np.size(itwf_err)), itwf_err, '-b')
plt.semilogy(np.arange(0, np.size(rwf_err)), rwf_err, '-c',label="Reshaped Wirtinger Flow")
plt.semilogy(np.arange(0, np.size(irwf_err)), irwf_err, '-g',label="Incrementally Reshaped Wirtinger Flow")
plt.semilogy(np.arange(0, np.size(imrwf_err)), imrwf_err, '-y',label="Incrementally Minibatched Reshaped Wirtinger Flow")
plt.semilogy(np.arange(0, np.size(rwf_err)), rwf_err, '-c',label="RWF")
plt.semilogy(np.arange(0, np.size(irwf_err)), irwf_err, '-g',label="IRWF")
plt.semilogy(np.arange(0, np.size(imrwf_err)), imrwf_err, '-y',label="IMRWF")
plt.xlabel('Iteration')
plt.ylabel('Relative error (log10)')
plt.title('Wirtinger Flow Variants')
plt.legend()
plt.savefig('Wirtinger Flow Variants.pdf')
# fig = plt.figure()
fig = plt.gcf()
tikzplotlib_fix_ncols(fig)
tikzplotlib.save("wf_variants.tex")
# plt.savefig('Wirtinger Flow Variants.pdf')
# tikzplotlib.save("wf_variants.tex")
print("misson accomplished in:")
end = time.time()
elapsed_time = end - start
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -13,7 +13,7 @@ params['n1'] = 64 # signal dimension
params['m'] = 10 * params['n1'] # number of measurements
params['cplx_flag'] = 1 # real: cplx_flag = 0; complex: cplx_flag = 1;
params['grad_type'] = 'TWF_Poiss' # 'TWF_Poiss': Poisson likelihood
params['T'] = 1500 # number of iterations
params['T'] = 800 # number of iterations
params['npower_iter'] = 30 # number of power iterations
npower_iter = params['npower_iter'] # Number of power iterations
tau0 = 330
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment