from pathlib import Path
import torch
from torch.utils.data import Dataset

# from torch.utils.data import Sampler
import numpy as np
import configparser

# class SubsetSampler(Sampler[int]):
#     """
#     Samples elements from a given list of indices.

#     :param indices: List of indices to sample from.
#     :type indices: list[int]
#     """

#     def __init__(self, indices):
#         self.indices = indices

#     def __iter__(self):
#         return iter(self.indices)

#     def __len__(self):
#         return len(self.indices)


def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
    """
    Load the simulated field data referenced by a config file.

    :param config_path: Path to the config file. May contain wildcards (*); all matching files are read.
    :param skipfirst: Number of symbols to skip at the start of the data.
    :param symbols: Number of symbols to load. If None, nos - skipfirst from the config is used.
    :param real: If True, return the magnitude (absolute value) of the data.
    :param normalize: If True, normalize each channel to unit peak power.
    :param device: Device of the returned tensor.
    :param dtype: Dtype of the returned tensor.
    :return: Tuple of the data tensor and the parsed config.
    """
    filepath = Path(config_path)
    filepath = filepath.parent.glob(filepath.name)
    config = configparser.ConfigParser()
    config.read(filepath)
    path_elements = (
        config["data"]["dir"],
        config["data"]["npy_dir"],
        config["data"]["file"],
    )
    datapath = Path("/".join(path_elements).replace('"', ""))
    sps = int(config["glova"]["sps"])

    if symbols is None:
        symbols = int(config["glova"]["nos"]) - skipfirst

    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]

    if normalize:
        # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
        a, b, c, d = np.square(data.T)
        a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
        data = np.sqrt(np.array([a, b, c, d]).T)

    if real:
        data = np.abs(data)

    config["glova"]["nos"] = str(symbols)

    data = torch.tensor(data, device=device, dtype=dtype)

    return data, config
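

# A minimal sketch of the kind of config file load_data() expects. The section and key
# names are taken from the code above; the concrete values are placeholders, not taken
# from the original project:
#
#   [data]
#   dir = "./data"
#   npy_dir = "npy"
#   file = "example_run.npy"
#
#   [glova]
#   sps = 128
#   nos = 4096
#
# With such a file, load_data("path/to/run.ini", normalize=True) would return the field
# data as a tensor of shape [symbols * sps, 4] (assuming the .npy file holds the four
# field components described in the FiberRegenerationDataset docstring) together with
# the parsed config.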
def roll_along(arr, shifts, dim):
    """
    Roll each slice of ``arr`` along dimension ``dim`` by its own shift.

    :param arr: Input tensor.
    :param shifts: Per-slice shift amounts; must have one fewer dimension than ``arr``.
        A positive shift moves elements toward higher indices (i.e. delays the sequence).
    :param dim: Dimension along which to roll.
    :return: Tensor with each slice rolled by its corresponding shift.
    """
    # https://stackoverflow.com/a/76920720
    # (c) Mateen Ulhaq, 2023
    # CC BY-SA 4.0
    shifts = torch.tensor(shifts, device=arr.device)  # keep indices on the same device as arr
    assert arr.ndim - 1 == shifts.ndim
    dim %= arr.ndim
    shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
    dim_indices = torch.arange(arr.shape[dim], device=arr.device).reshape(shape)
    indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
    return torch.gather(arr, dim, indices)
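

# Quick illustration (values chosen purely for demonstration): each row of a [2, 5]
# tensor is rolled along dim=1 by its own shift.
#
#   x = torch.arange(10).reshape(2, 5)
#   roll_along(x, [1, 2], dim=1)
#   # -> tensor([[4, 0, 1, 2, 3],    (row 0 rolled right by 1)
#   #            [8, 9, 5, 6, 7]])   (row 1 rolled right by 2)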
class FiberRegenerationDataset(Dataset):
    """
    Dataset for fiber regeneration training.

    The dataset is loaded from a configuration file, which must contain (at least) the following sections:
    ```
    [data]
    dir = <data_dir>
    npy_dir = <npy_dir>
    file = <data_file>

    [glova]
    sps = <samples per symbol>
    nos = <number of symbols>
    ```
    The data is loaded from the file `<data_dir>/<npy_dir>/<data_file>` and is assumed to be in the following format:
    ```
    [ E_in_x,
      E_in_y,
      E_out_x,
      E_out_y ]
    ```

    The data is split into overlapping slices, each consisting of a (fractional) number of symbols.
    The target can be delayed relative to the input data by a (fractional) number of symbols.
    The x and y channels can be delayed relative to each other by a (fractional) number of symbols.
    """

    def __init__(
        self,
        file_path: str | Path,
        symbols: int | float,
        *,
        output_dim: int = None,
        target_delay: float | int = 0,
        xy_delay: float | int = 0,
        drop_first: float | int = 0,
        dtype: torch.dtype = None,
        real: bool = False,
        device=None,
        **kwargs,
    ):
        """
        Initialize the dataset.

        :param file_path: Path to the config file describing the data. Can contain wildcards (*); all files matching the pattern are read by the config parser.
        :type file_path: str | pathlib.Path
        :param symbols: Number of symbols in each slice. Can be a float to specify a fraction of a symbol.
        :type symbols: float | int
        :param output_dim: Number of samples in each slice. The data is reduced by taking equally spaced samples. If unset, each slice will contain symbols*samples_per_symbol samples.
        :type output_dim: int, optional
        :param target_delay: Delay (in fractional symbols) between data and target. A positive delay means the target is delayed relative to the data. Default is 0.
        :type target_delay: float | int, optional
        :param xy_delay: Delay (in fractional symbols) between the x and y channels. A positive delay means the y channel is delayed relative to the x channel. Default is 0.
        :type xy_delay: float | int, optional
        :param drop_first: Number of symbols to drop from the beginning of the data.
        :type drop_first: float | int
        :param dtype: Torch dtype of the data tensor. If None, the dtype of the loaded data is kept.
        :type dtype: torch.dtype, optional
        :param real: If True, the absolute value of the (complex) data is used. Default is False.
        :type real: bool, optional
        :param device: Device the data tensor is placed on.
        :type device: torch.device | str, optional
        :param kwargs: Additional options: ``num_symbols``, ``faux``, ``ovrd_target_delay_samples``, ``ovrd_xy_delay_samples``.
        """

        # check types
        assert isinstance(file_path, (str, Path)), "file_path must be a string or a Path"
        assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
        assert output_dim is None or isinstance(output_dim, int), "output_dim must be an integer"
        assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
        assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
        assert isinstance(drop_first, int), "drop_first must be an integer"

        # check values
        assert symbols > 0, "symbols must be positive"
        assert output_dim is None or output_dim > 0, "output_dim must be positive or None"
        assert drop_first >= 0, "drop_first must be non-negative"

        # faux=True generates a small synthetic data set instead of loading from disk (useful for testing)
        faux = kwargs.pop("faux", False)

        if faux:
            data_raw = np.array(
                [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
                dtype=np.complex128,
            )
            data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
            self.config = {
                "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                "glova": {"sps": 128},
            }
        else:
            data_raw, self.config = load_data(
                file_path,
                skipfirst=drop_first,
                symbols=kwargs.pop("num_symbols", None),
                real=real,
                normalize=True,
                device=device,
                dtype=dtype,
            )

        self.device = data_raw.device

        self.samples_per_symbol = int(self.config["glova"]["sps"])
        self.samples_per_slice = int(symbols * self.samples_per_symbol)
        self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol

        self.output_dim = output_dim or self.samples_per_slice
        self.target_delay = target_delay or 0
        self.xy_delay = xy_delay or 0

        ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
        ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)

        self.target_delay_samples = (
            ovrd_target_delay_samples
            if ovrd_target_delay_samples is not None
            else int(self.target_delay * self.samples_per_symbol)
        )
        self.xy_delay_samples = (
            ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
        )

        # data_raw = torch.tensor(data_raw, dtype=dtype)

        # data layout
        # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
        #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
        #   ...
        #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]

        data_raw = data_raw.transpose(0, 1)

        # data layout
        # [ E_in_x[0:N],
        #   E_in_y[0:N],
        #   E_out_x[0:N],
        #   E_out_y[0:N] ]

        # delay the y channels by xy_delay_samples relative to the x channels (example value: 3)
        # [ E_in_x [0:N],      [ E_in_x [ 0:N ],      [ E_in_x [3:N ],
        #   E_in_y [0:N],   ->   E_in_y [-3:N-3],  ->   E_in_y [0:N-3],
        #   E_out_x[0:N],        E_out_x[ 0:N ],        E_out_x[3:N ],
        #   E_out_y[0:N] ]       E_out_y[-3:N-3] ]      E_out_y[0:N-3] ]

        if self.xy_delay_samples != 0:
            data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
            if self.xy_delay_samples > 0:
                data_raw = data_raw[:, self.xy_delay_samples :]
            elif self.xy_delay_samples < 0:
                data_raw = data_raw[:, : self.xy_delay_samples]

        # delay the fiber input data (target) by target_delay_samples relative to the fiber output data (input)
        # (example value: 5)
        # [ E_in_x [0:N],      [ E_in_x [-5:N-5],      [ E_in_x [0:N-5],
        #   E_in_y [0:N],   ->   E_in_y [-5:N-5],  ->   E_in_y [0:N-5],
        #   E_out_x[0:N],        E_out_x[ 0:N ],        E_out_x[5:N ],
        #   E_out_y[0:N] ]       E_out_y[ 0:N ] ]       E_out_y[5:N ] ]

        if self.target_delay_samples != 0:
            data_raw = roll_along(
                data_raw,
                [self.target_delay_samples, self.target_delay_samples, 0, 0],
                dim=1,
            )
            if self.target_delay_samples > 0:
                data_raw = data_raw[:, self.target_delay_samples :]
            elif self.target_delay_samples < 0:
                data_raw = data_raw[:, : self.target_delay_samples]

        data_raw = data_raw.view(2, 2, -1)
        # data layout
        # [ [E_in_x, E_in_y],
        #   [E_out_x, E_out_y] ]

        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
        self.data = self.data.movedim(-2, 0)
        # -> [no_slices, 2, 2, samples_per_slice]

        # data layout
        # [
        #   [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
        #   [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
        #   ...
        # ] -> [no_slices, 2, 2, samples_per_slice]

        ...

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
        else:
            data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()

            # reduce the data by taking self.output_dim equally spaced samples
            data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
            data = data.view(data.shape[0], self.output_dim, -1)
            data = data[:, :, 0]

            # the target corresponds to the middle of the data window, as the output sample is influenced by the data before and after it
            target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
            target = target.view(target.shape[0], self.output_dim, -1)
            target = target[:, 0, target.shape[2] // 2]

            data = data.transpose(0, 1).flatten().squeeze()
            target = target.flatten().squeeze()

            # data layout:
            # [sample_x0, sample_y0, sample_x1, sample_y1, ...]
            # target layout:
            # [sample_x0, sample_y0]

            return data, target
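

# Minimal usage sketch (not part of the original training pipeline): it exercises the
# faux=True code path, which builds synthetic placeholder data, so no config file or
# .npy data is needed. The "unused.ini" path below is a dummy placeholder and is never
# read.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # one symbol per slice; the faux data uses 128 samples per symbol
    dataset = FiberRegenerationDataset("unused.ini", symbols=1, faux=True)
    print(len(dataset))  # number of (overlapping) slices

    data, target = dataset[0]
    print(data.shape, target.shape)  # interleaved x/y input samples and the x/y target pair

    # batching works as with any map-style torch Dataset
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch_data, batch_target = next(iter(loader))
    print(batch_data.shape, batch_target.shape)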