add SlicedDataset class and utility scripts; refactor: remove _path_fix.py and update imports;

This commit is contained in:
Joseph Hopfmüller
2024-11-17 01:04:33 +01:00
parent 90aa6dbaf8
commit 87f40fc37c
7 changed files with 172 additions and 11 deletions

View File

@@ -1,9 +0,0 @@
import sys
from pathlib import Path
# hack to add the parent directory to the path -> pypho doesn't have to be installed as package
parent_dir = Path(__file__).parent

# Walk upwards until a directory containing ./pypho/pypho is found.
# The filesystem root is detected portably as "the directory that is its own
# parent": comparing against Path("/") never matches a Windows drive root (or a
# relative "."), which would turn this loop into an endless one.
while not (parent_dir / "pypho" / "pypho").exists() and parent_dir != parent_dir.parent:
    parent_dir = parent_dir.parent

# Name the target once; this also avoids nesting double quotes inside a
# double-quoted f-string, which is a syntax error before Python 3.12.
pypho_dir = parent_dir / "pypho"
print(f"Adding '{pypho_dir}' to 'sys.path' to enable import of '{pypho_dir / 'pypho'}'")
sys.path.append(str(pypho_dir))

View File

@@ -19,9 +19,8 @@ import time
from matplotlib import pyplot as plt # noqa: F401 from matplotlib import pyplot as plt # noqa: F401
import numpy as np import numpy as np
import _path_fix # noqa: F401 import path_fix
import pypho import pypho
# import inspect
default_config = f""" default_config = f"""
[glova] [glova]
@@ -498,6 +497,7 @@ def plot_eye_diagram(
if __name__ == "__main__": if __name__ == "__main__":
path_fix.show_log()
config = get_config() config = get_config()
length_ranges = [1000, 10000] length_ranges = [1000, 10000]

View File

@@ -0,0 +1,51 @@
# move into dir single-core-regen before running
from util.dataset import SlicedDataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np
def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
    """
    Overlay symbol-spaced slices of a dataset as an eye diagram.

    A 2x2 grid of axes is created: the left column shows the second element
    of each dataset item (``E_in``), the right column the first (``E_out``);
    the rows correspond to index 0 and 1 of each field (presumably the x and
    y polarisation -- confirm against the data layout). Every curve is the
    squared magnitude of the samples in one slice.

    :param dataset: dataset exposing ``samples_per_symbol``,
        ``symbols_per_slice`` and ``samples_per_slice`` and supporting slice
        indexing
    :param no_symbols: number of slices to overlay; defaults to ``len(dataset)``
    :type no_symbols: int
    :param offset: start half a symbol into the data
    :type offset: bool
    :param show: call ``plt.show()`` once all curves are drawn
    :type show: bool
    """
    if no_symbols is None:
        no_symbols = len(dataset)

    _, axs = plt.subplots(2, 2, sharex=True, sharey=True)
    time_axis = np.linspace(0, dataset.symbols_per_slice, dataset.samples_per_slice)

    sps = dataset.samples_per_symbol
    start = sps // 2 if offset else 0
    stop = sps * no_symbols + start

    # stepping by one symbol makes consecutive slices line up on top of each other
    for E_out, E_in in dataset[start:stop:sps]:
        # column 0 -> input field, column 1 -> output field
        for column, field in enumerate((E_in, E_out)):
            for row in (0, 1):
                power = np.abs(field[row].numpy()) ** 2
                axs[row, column].plot(time_axis, power, alpha=0.05, color='C0')

    if show:
        plt.show()
# def plt_dataloader(dataloader, show=True):
# _, axs = plt.subplots(2,2, sharex=True, sharey=True)
# E_outs, E_ins = next(iter(dataloader))
# for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
# xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
# axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
# axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
# if show:
# plt.show()
if __name__ == "__main__":
    # Demo run: load one recorded configuration, report the shape of a single
    # sample and draw its eye diagram. Paths are relative to the working
    # directory, so this has to be started from the directory containing "data".
    dataset = SlicedDataset(
        "data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini",
        symbols=1,
        drop_first=100,
    )
    print(dataset[0][0].shape)

    # the figure is only prepared here; it is displayed by the final plt.show()
    eye_dataset(dataset, no_symbols=1000, offset=True, show=False)

    train_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=False)
    # plt_dataloader(train_loader, show=False)

    plt.show()

View File

@@ -0,0 +1,53 @@
from pathlib import Path
import torch
from torch.utils.data import Dataset
import numpy as np
import configparser
class SlicedDataset(Dataset):
    """
    Sliding-window dataset over a recorded ``.npy`` array.

    The location of the array and the number of samples per symbol are read
    from an ini-style configuration file. The array is cut into overlapping
    slices of ``symbols`` symbols each, advancing one sample per slice.
    """

    def __init__(self, config_path, symbols, drop_first=0):
        """
        Initialize the dataset.

        :param config_path: Path to the configuration file
        :type config_path: str
        :param symbols: Length of every slice in symbols
        :type symbols: int
        :param drop_first: Number of leading symbols to discard before slicing
        :type drop_first: int
        :raises FileNotFoundError: if the configuration file cannot be read
        """
        super().__init__()

        self.config = configparser.ConfigParser()
        # ConfigParser.read() silently skips files it cannot open and returns
        # the list of files it did parse; fail here with a clear message
        # instead of a confusing KeyError on the first section lookup.
        if not self.config.read(Path(config_path)):
            raise FileNotFoundError(f"Could not read config file '{config_path}'")

        # values in the config may be wrapped in double quotes
        data_cfg = self.config['data']
        self.data_path = (
            Path(data_cfg['dir'].strip('"'))
            / data_cfg['npy_dir'].strip('"')
            / data_cfg['file'].strip('"')
        )

        self.symbols_per_slice = symbols
        self.samples_per_symbol = int(self.config['glova']['sps'])
        self.samples_per_slice = self.symbols_per_slice * self.samples_per_symbol

        # discard the first `drop_first` symbols worth of samples
        data_raw = torch.tensor(np.load(self.data_path))[drop_first * self.samples_per_symbol:]

        # [no_samples, 4] -> [4, no_samples] -> [2, 2, no_samples]
        data_raw = data_raw.transpose(0, 1)
        data_raw = data_raw.view(2, 2, -1)

        # one window per start sample (step=1), then move the window index to
        # the front -> [no_slices, 2, 2, samples_per_slice]
        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
        self.data = self.data.movedim(-2, 0)

    def __len__(self):
        """Return the number of slices."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """
        Return one item, or a list of items for a slice.

        An item is the tuple ``(data[idx, 1], data[idx, 0])`` with singleton
        dimensions squeezed away (the callers in this repo unpack it as
        ``E_out, E_in``).

        :param idx: integer index or ``slice``
        :return: tuple of two tensors, or a list of such tuples for a slice
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]
        return (self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze())
# Running this module directly does nothing; the guard is a placeholder.
if __name__ == "__main__":
    pass

29
src/util/add_parent.py Normal file
View File

@@ -0,0 +1,29 @@
# add_parent.py
#
# This file is part of the repo "optical-regeneration"
# https://git.suuppl.dev/seppl/optical-regeneration.git
#
# (c) Joseph Hopfmüller, 2024
# Licensed under the EUPL
#
# Full license text in LICENSE file
###
# copy this file into the directory where you want to use pypho
import sys
from pathlib import Path
# directory this file lives in; the directory above it is made importable
__parent_dir = Path(__file__).parents[0]
sys.path.append(str(__parent_dir.parent))

# messages are collected rather than printed so importing stays silent;
# show_log() prints them on request
__log = [f"Added '{__parent_dir.parent}' to 'PATH'"]
def show_log():
    """Print every message collected in the module log, one per line."""
    # guard: print() without arguments would emit an empty line
    if __log:
        print(*__log, sep="\n")

37
src/util/add_pypho.py Normal file
View File

@@ -0,0 +1,37 @@
# add_pypho.py
#
# This file is part of the repo "optical-regeneration"
# https://git.suuppl.dev/seppl/optical-regeneration.git
#
# (c) Joseph Hopfmüller, 2024
# Licensed under the EUPL
#
# Full license text in LICENSE file
###
# copy this file into the directory where you want to use pypho
import sys
from pathlib import Path
__log = []

# start the search in the directory where this file lives
__parent_dir = Path(__file__).parent

# search upwards for a dir containing ./pypho/pypho, then add the lower ./pypho
# The filesystem root is detected as "the directory that is its own parent";
# comparing against Path("/") never matches a Windows drive root (or a relative
# "."), which would make this loop run forever.
while not (__parent_dir / "pypho" / "pypho").exists() and __parent_dir != __parent_dir.parent:
    __parent_dir = __parent_dir.parent

# decide by what was actually found rather than by where the walk stopped
if (__parent_dir / "pypho" / "pypho").exists():
    sys.path.append(str(__parent_dir / "pypho"))
    # single quotes inside the braces keep this valid before Python 3.12
    __log.append(f"Added '{__parent_dir / 'pypho'}' to 'PATH'")
else:
    __log.append('pypho not found')
def show_log():
    """Print every message collected in the module log, one per line."""
    # guard: print() without arguments would emit an empty line
    if __log:
        print(*__log, sep="\n")