add: regen.py (main hyperparameter training file)
feat: add utility functions for fiber dataset visualization and hyperparameter training; housekeeping: rename dataset.py -> datasets.py
src/single-core-regen/util/datasets.py (new file, 256 lines)
@@ -0,0 +1,256 @@
from pathlib import Path
import torch
from torch.utils.data import Dataset
import numpy as np
import configparser


def load_data(config_path, skipfirst=0, num_symbols=None):
    filepath = Path(config_path)
    # resolve wildcards in the file name; configparser reads every match
    filepath = filepath.parent.glob(filepath.name)
    config = configparser.ConfigParser()
    config.read(filepath)
    path_elements = (
        config["data"]["dir"],
        config["data"]["npy_dir"],
        config["data"]["file"],
    )
    datapath = Path("/".join(path_elements).replace('"', ""))
    sps = int(config["glova"]["sps"])  # samples per symbol

    if num_symbols is None:
        num_symbols = int(config["glova"]["nos"]) - skipfirst

    # skip the first `skipfirst` symbols, then keep `num_symbols` symbols
    data = np.load(datapath)[skipfirst * sps : num_symbols * sps + skipfirst * sps]
    config["glova"]["nos"] = str(num_symbols)

    return data, config
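
# Example call (illustrative path; the wildcard pattern is an assumed name):
#     data, config = load_data("data/config_*.ini", skipfirst=100)
#     # data has shape (num_symbols * sps, 4); the columns are
#     # E_in_x, E_in_y, E_out_x, E_out_y

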
def roll_along(arr, shifts, dim):
    # https://stackoverflow.com/a/76920720
    # (c) Mateen Ulhaq, 2023
    # CC BY-SA 4.0
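    # Circular shift along `dim` with an individual shift per row; elements
    # wrap around. Example (illustrative):
    #     roll_along(torch.arange(6).reshape(2, 3), [0, 1], dim=1)
    #     # [[0, 1, 2],       [[0, 1, 2],
    #     #  [3, 4, 5]]   ->   [5, 3, 4]]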
    shifts = torch.tensor(shifts)
    assert arr.ndim - 1 == shifts.ndim
    dim %= arr.ndim
    shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
    dim_indices = torch.arange(arr.shape[dim]).reshape(shape)
    indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
    return torch.gather(arr, dim, indices)


class FiberRegenerationDataset(Dataset):
    """
    Dataset for fiber regeneration training.

    The dataset is loaded from a configuration file, which must contain (at least) the following sections:
    ```
    [data]
    dir = <data_dir>
    npy_dir = <npy_dir>
    file = <data_file>

    [glova]
    sps = <samples per symbol>
    nos = <number of symbols>
    ```
    The data is loaded from the file `<data_dir>/<npy_dir>/<data_file>` and is assumed to be in the following format:
    ```
    [ E_in_x,
      E_in_y,
      E_out_x,
      E_out_y ]
    ```

    The data is cut into overlapping slices, each spanning a (fractional) number of symbols.
    The target can be delayed relative to the input data by a (fractional) number of symbols.
    The x and y channels can be delayed relative to each other by a (fractional) number of symbols.
    """

    def __init__(
        self,
        file_path: str | Path,
        symbols: int | float,
        *,
        data_size: int | None = None,
        target_delay: float | int = 0,
        xy_delay: float | int = 0,
        drop_first: float | int = 0,
        **kwargs,
    ):
"""
|
||||
Initialize the dataset.
|
||||
|
||||
:param file_path: Path to the data file. Can contain wildcards (*). The first
|
||||
:type file_path: str | pathlib.Path
|
||||
:param symbols: Number of symbols in each slice. Can be a float to specify a fraction of a symbol.
|
||||
:type symbols: float | int
|
||||
:param data_size: Number of samples in each slice. The data is reduced by taking equally spaced samples. If unset, each slice will contain symbols*samples_per_symbol samples.
|
||||
:type data_size: int, optional
|
||||
:param target_delay: Delay (in fractional symbols) between data and target. A positive delay means the target is delayed relative to the data. Default is 0.
|
||||
:type target_delay: float | int, optional
|
||||
:param xy_delay: Delay (in fractional symbols) between the x and y channels. A positive delay means the y channel is delayed relative to the x channel. Default is 0.
|
||||
:type xy_delay: float | int, optional
|
||||
:param drop_first: Number of (fractional) symbols to drop from the beginning
|
||||
:type drop_first: float | int
|
||||
"""
|
||||
|
||||
        # check types
        assert isinstance(file_path, (str, Path)), (
            "file_path must be a string or a Path"
        )
        assert isinstance(symbols, (float, int)), (
            "symbols must be a float or an integer"
        )
        assert data_size is None or isinstance(data_size, int), (
            "data_size must be an integer"
        )
        assert isinstance(target_delay, (float, int)), (
            "target_delay must be a float or an integer"
        )
        assert isinstance(xy_delay, (float, int)), (
            "xy_delay must be a float or an integer"
        )
        assert isinstance(drop_first, (float, int)), (
            "drop_first must be a float or an integer"
        )

        # check values
        assert symbols > 0, "symbols must be positive"
        assert data_size is None or data_size > 0, "data_size must be positive or None"
        assert drop_first >= 0, "drop_first must be non-negative"

faux = kwargs.pop("faux", False)
|
||||
|
||||
        if faux:
            data_raw = np.array(
                [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
                dtype=np.complex128,
            )
            self.config = {
                "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                "glova": {"sps": 128},
            }
        else:
            data_raw, self.config = load_data(file_path)

        # everything below uses torch ops (roll_along, unfold), so convert once here
        data_raw = torch.as_tensor(data_raw)

self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||
self.samples_per_slice = int(symbols * self.samples_per_symbol)
|
||||
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
|
||||
|
||||
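        # e.g. sps = 128 and symbols = 0.5 give samples_per_slice = 64;
        # symbols_per_slice recovers the symbol count actually covered after
        # rounding to whole samples
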
        self.data_size = data_size or self.samples_per_slice
        self.target_delay = target_delay or 0
        self.xy_delay = xy_delay or 0

ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
|
||||
ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
|
||||
ovrd_drop_first_samples = kwargs.pop("ovrd_drop_first_samples", None)
|
||||
|
||||
self.target_delay_samples = (
|
||||
ovrd_target_delay_samples
|
||||
if ovrd_target_delay_samples is not None
|
||||
else int(self.target_delay * self.samples_per_symbol)
|
||||
)
|
||||
self.xy_delay_samples = (
|
||||
ovrd_xy_delay_samples
|
||||
if ovrd_xy_delay_samples is not None
|
||||
else int(self.xy_delay * self.samples_per_symbol)
|
||||
)
|
||||
drop_first_samples = (
|
||||
ovrd_drop_first_samples
|
||||
if ovrd_drop_first_samples is not None
|
||||
else int(drop_first * self.samples_per_symbol)
|
||||
)
|
||||
|
||||
        # drop samples from the beginning
        data_raw = data_raw[drop_first_samples:]

        # data layout
        # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
        #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
        #   ...
        #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]

        data_raw = data_raw.transpose(0, 1)

        # data layout
        # [ E_in_x[0:N],
        #   E_in_y[0:N],
        #   E_out_x[0:N],
        #   E_out_y[0:N] ]

        # shift y data by xy_delay_samples relative to the x data (example value: 3)
        # [ E_in_x [0:N],      [ E_in_x [ 0:N ],      [ E_in_x [3:N  ],
        #   E_in_y [0:N],  ->    E_in_y [-3:N-3],  ->   E_in_y [0:N-3],
        #   E_out_x[0:N],        E_out_x[ 0:N ],        E_out_x[3:N  ],
        #   E_out_y[0:N] ]       E_out_y[-3:N-3] ]      E_out_y[0:N-3] ]

        if self.xy_delay_samples != 0:
            data_raw = roll_along(
                data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
            )
            if self.xy_delay_samples > 0:
                data_raw = data_raw[:, self.xy_delay_samples :]
            elif self.xy_delay_samples < 0:
                data_raw = data_raw[:, : self.xy_delay_samples]

        # shift fiber input data (target) by target_delay_samples relative to the fiber output data (input)
        # (example value: 5)
        # [ E_in_x [0:N],      [ E_in_x [-5:N-5],      [ E_in_x [0:N-5],
        #   E_in_y [0:N],  ->    E_in_y [-5:N-5],  ->    E_in_y [0:N-5],
        #   E_out_x[0:N],        E_out_x[ 0:N ],         E_out_x[5:N  ],
        #   E_out_y[0:N] ]       E_out_y[ 0:N ] ]        E_out_y[5:N  ] ]

        if self.target_delay_samples != 0:
            data_raw = roll_along(
                data_raw,
                [self.target_delay_samples, self.target_delay_samples, 0, 0],
                dim=1,
            )
            if self.target_delay_samples > 0:
                data_raw = data_raw[:, self.target_delay_samples :]
            elif self.target_delay_samples < 0:
                data_raw = data_raw[:, : self.target_delay_samples]

        # reshape rather than view: the tensor is not guaranteed to be
        # contiguous after the transpose/slicing above
        data_raw = data_raw.reshape(2, 2, -1)
        # data layout
        # [ [E_in_x, E_in_y],
        #   [E_out_x, E_out_y] ]

        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
        self.data = self.data.movedim(-2, 0)
        # -> [no_slices, 2, 2, samples_per_slice]

        # data layout
        # [
        #  [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
        #  [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
        #  ...
        # ] -> [no_slices, 2, 2, samples_per_slice]

        ...

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
        else:
            # input = fiber output (index 1), target = fiber input (index 0)
            data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()

            # reduce by taking self.data_size equally spaced samples
            data = data[:, : data.shape[1] // self.data_size * self.data_size]
            data = data.view(data.shape[0], self.data_size, -1)
            data = data[:, :, 0]
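            # e.g. samples_per_slice = 8 and data_size = 4 keep samples
            # 0, 2, 4, 6 of each channel
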
            # the target corresponds to the latest data point -> try to regenerate that
            target = target[:, : target.shape[1] // self.data_size * self.data_size]
            target = target.view(target.shape[0], self.data_size, -1)
            target = target[:, 0, 0]

            data = data.transpose(0, 1).flatten().squeeze()
            target = target.flatten().squeeze()

            # data layout:
            # [sample_x0, sample_y0, sample_x1, sample_y1, ...]
            # target layout:
            # [sample_x0, sample_y0]

            return data, target
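

# A minimal usage sketch (illustrative path and values):
#     from torch.utils.data import DataLoader
#     ds = FiberRegenerationDataset("path/to/config.ini", symbols=4, data_size=64)
#     x, y = ds[0]    # x: 2 * data_size interleaved fiber-output samples
#                     # y: the two fiber-input samples [E_in_x, E_in_y]
#     loader = DataLoader(ds, batch_size=32, shuffle=True)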