Skip to content
Snippets Groups Projects
Commit 71a37c4f authored by Yarkin Colak's avatar Yarkin Colak
Browse files

Initial commit

parents
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn
import glob
import numpy as np
from torch.utils.data import Dataset
import pytorch_lightning as pl
class PhaseDataset(pl.LightningDataModule):
def __init__(self, input_folder, target_folder, transform=None):
super().__init__()
self.input_folder = input_folder
self.target_folder = target_folder
self.transform = transform
def setup(self, stage=None):
input_files = sorted(glob.glob(self.input_folder))
target_files = sorted(glob.glob(self.target_folder))
# By sorting the files, we establish a fixed order for loading
# and processing the files, making it easier to match inputs and
# targets correctly.
self.input_data = []
self.target_data = []
for input_file in input_files:
tensor = torch.flatten(torch.from_numpy(np.load(input_file)))
self.input_data.append(tensor)
# ".append()" adds the specified tensor object to the end of the list.
# In this case, it adds the processed tensor (representing a single
# input sample) to the self.input_data list.
for target_file in target_files:
tensor = torch.flatten(torch.from_numpy(np.load(target_file)))
self.target_data.append(tensor)
def __len__(self):
return len(self.input_data)
def __getitem__(self, idx):
input_sample = self.input_data[idx]
target_sample = self.target_data[idx]
if self.transform:
target_sample = self.transform(target_sample)
return input_sample.float(), target_sample.float()
#------------------------------------------------------------------------------------
class PhaseDataset2(pl.LightningDataModule):
def __init__(self, input_folder, target_folder, transform=None):
super().__init__()
self.input_folder = input_folder
self.target_folder = target_folder
self.transform = transform
def setup(self, stage=None):
input_files = sorted(glob.glob(self.input_folder))
target_files = sorted(glob.glob(self.target_folder))
self.input_data = []
self.target_data = []
for input_file in input_files:
tensor = torch.from_numpy(np.load(input_file))
self.input_data.append(tensor)
for target_file in target_files:
tensor = torch.flatten(torch.from_numpy(np.load(target_file)))
self.target_data.append(tensor)
def __len__(self):
return len(self.input_data)
def __getitem__(self, idx):
input_sample = self.input_data[idx]
target_sample = self.target_data[idx]
if self.transform:
target_sample = self.transform(target_sample)
return input_sample.float(), target_sample.float()
#------------------------------------------------------------------------------------
def sinus(x):
y = torch.sin(x)
return y
def arctan(x):
y = torch.atan(x)
return y
#------------------------------------------------------------------------------
"""
control the LAST_lightning_dataSets.py / Yarkin
"""
dataset = PhaseDataset("C:\\Users\\yrknc\\OneDrive\\Masaüstü\\Input/*.npy", "C:\\Users\\yrknc\\OneDrive\\Masaüstü\\Target/*.npy", transform=None)
# Call the setup method to load the data
dataset.setup()
# Access individual samples
input_sample, target_sample = dataset[0]
print(input_sample, "\nthis is length of input_sample in lightning: ", len(input_sample))
print(target_sample, "\nthis is length of target_sample in lightning: ", len(target_sample))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment