Skip to content
Snippets Groups Projects
Select Git revision
1 result Searching

config10.py

Blame
  • generating.py 1.25 KiB
    import torch
    import torch.utils.data as Data
    from scipy.linalg import eigvalsh
    import useful_utils
    import numpy as np
    import math
    import params
    
    
    def gen_x_Amatrix_y(n, m, N, cplx_flag, DEVICE):
        """Create random signals, a random measurement matrix, and |A x| measurements.

        Draws training signals ``x``, validation signals ``x_val`` and a
        measurement matrix ``Amatrix`` with i.i.d. standard-normal real parts and
        (when ``cplx_flag`` is 1) i.i.d. standard-normal imaginary parts, then
        forms the magnitude-only measurements ``y_i = |a_i x|`` for every signal.

        Args:
            n: Signal dimension (columns of ``x`` and of ``Amatrix``).
            m: Number of measurements (rows of ``Amatrix``).
            N: Number of signals in each of the training and validation sets.
            cplx_flag: Treated as 0 (real data) or 1 (complex data). It multiplies
                the imaginary parts and is the exponent of the sqrt(2) scale, so
                with 0 the tensors are still complex dtype but have a zero
                imaginary part.
            DEVICE: torch device on which all random tensors are created.

        Returns:
            Tuple ``(x, x_val, Amatrix, y, y_val)`` where ``x`` and ``x_val`` are
            ``(N, n)``, ``Amatrix`` is ``(m, n)`` (complex128), and ``y`` and
            ``y_val`` are real float64 tensors of shape ``(N, m)`` with
            ``y[k, i] = |a_i . x[k]|``.
        """

        def _draw(rows, cols):
            # Real part first, then imaginary part (zeroed when cplx_flag == 0).
            # Keeping this draw order keeps results reproducible under a fixed seed.
            real = torch.randn((rows, cols), dtype=torch.double, device=DEVICE)
            imag = torch.randn((rows, cols), dtype=torch.double, device=DEVICE) * 1j * cplx_flag
            return real + imag

        x = _draw(N, n)
        x_val = _draw(N, n)
        # Dividing by sqrt(2) in the complex case gives each entry of Amatrix
        # unit expected squared magnitude (real and imaginary parts each get
        # variance 1/2); in the real case the scale is 1.
        Amatrix = _draw(m, n) / math.sqrt(2) ** cplx_flag

        # Plain transpose (not conjugate): row k of y holds |a_i x_k| for all i.
        # torch.abs takes no `device` argument; operands already live on DEVICE.
        y_trans = torch.abs(Amatrix @ x.T)
        y_trans_val = torch.abs(Amatrix @ x_val.T)
        return x, x_val, Amatrix, y_trans.T, y_trans_val.T
    
    # gen_x_Amatrix_y(params.n,params.m,params.N,params.cplx_flag)