from pathlib import Path import torch from torch.utils.data import Dataset # from torch.utils.data import Sampler import numpy as np import configparser # class SubsetSampler(Sampler[int]): # """ # Samples elements from a given list of indices. # :param indices: List of indices to sample from. # :type indices: list[int] # """ # def __init__(self, indices): # self.indices = indices # def __iter__(self): # return iter(self.indices) # def __len__(self): # return len(self.indices) def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None): filepath = Path(config_path) filepath = filepath.parent.glob(filepath.name) config = configparser.ConfigParser() config.read(filepath) path_elements = ( config["data"]["dir"], config["data"]["npy_dir"], config["data"]["file"], ) datapath = Path("/".join(path_elements).replace('"', "")) sps = int(config["glova"]["sps"]) if symbols is None: symbols = int(config["glova"]["nos"]) - skipfirst data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps] if normalize: # square gets normalized to 1, as the power is (proportional to) the square of the amplitude a, b, c, d = np.square(data.T) a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d)) data = np.sqrt(np.array([a, b, c, d]).T) if real: data = np.abs(data) config["glova"]["nos"] = str(symbols) data = torch.tensor(data, device=device, dtype=dtype) return data, config def roll_along(arr, shifts, dim): # https://stackoverflow.com/a/76920720 # (c) Mateen Ulhaq, 2023 # CC BY-SA 4.0 shifts = torch.tensor(shifts) assert arr.ndim - 1 == shifts.ndim dim %= arr.ndim shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1) dim_indices = torch.arange(arr.shape[dim]).reshape(shape) indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim] return torch.gather(arr, dim, indices) class FiberRegenerationDataset(Dataset): """ Dataset for fiber regeneration training. The dataset is loaded from a configuration file, which must contain (at least) the following sections: ``` [data] dir = npy_dir = file = [glova] sps = ``` The data is loaded from the file `//` and is assumed to be in the following format: ``` [ E_in_x, E_in_y, E_out_x, E_out_y ] ``` The dataset is sliced into slices, where each slice consists of a (fractional) number of symbols. The target can be delayed relative to the input data by a (fractional) number of symbols. The x and y channels can be delayed relative to each other by a (fractional) number of symbols. """ def __init__( self, file_path: str | Path, symbols: int | float, *, output_dim: int = None, target_delay: float | int = 0, xy_delay: float | int = 0, drop_first: float | int = 0, dtype: torch.dtype = None, real: bool = False, device=None, **kwargs, ): """ Initialize the dataset. :param file_path: Path to the data file. Can contain wildcards (*). The first :type file_path: str | pathlib.Path :param symbols: Number of symbols in each slice. Can be a float to specify a fraction of a symbol. :type symbols: float | int :param data_size: Number of samples in each slice. The data is reduced by taking equally spaced samples. If unset, each slice will contain symbols*samples_per_symbol samples. :type data_size: int, optional :param target_delay: Delay (in fractional symbols) between data and target. A positive delay means the target is delayed relative to the data. Default is 0. :type target_delay: float | int, optional :param xy_delay: Delay (in fractional symbols) between the x and y channels. A positive delay means the y channel is delayed relative to the x channel. Default is 0. :type xy_delay: float | int, optional :param drop_first: Number of (fractional) symbols to drop from the beginning :type drop_first: float | int """ # check types assert isinstance(file_path, str), "file_path must be a string" assert isinstance(symbols, (float, int)), "symbols must be a float or an integer" assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer" assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer" assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer" assert isinstance(drop_first, int), "drop_first must be an integer" # check values assert symbols > 0, "symbols must be positive" assert output_dim is None or output_dim > 0, "output_len must be positive or None" assert drop_first >= 0, "drop_first must be non-negative" faux = kwargs.pop("faux", False) if faux: data_raw = np.array( [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)], dtype=np.complex128, ) data_raw = torch.tensor(data_raw, device=device, dtype=dtype) self.config = { "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"}, "glova": {"sps": 128}, } else: data_raw, self.config = load_data( file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype, ) self.device = data_raw.device self.samples_per_symbol = int(self.config["glova"]["sps"]) self.samples_per_slice = int(symbols * self.samples_per_symbol) self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol self.output_dim = output_dim or self.samples_per_slice self.target_delay = target_delay or 0 self.xy_delay = xy_delay or 0 ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None) ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None) self.target_delay_samples = ( ovrd_target_delay_samples if ovrd_target_delay_samples is not None else int(self.target_delay * self.samples_per_symbol) ) self.xy_delay_samples = ( ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol) ) # data_raw = torch.tensor(data_raw, dtype=dtype) # data layout # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0], # [E_in_x1, E_in_y1, E_out_x1, E_out_y1], # ... # [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ] data_raw = data_raw.transpose(0, 1) # data layout # [ E_in_x[0:N], # E_in_y[0:N], # E_out_x[0:N], # E_out_y[0:N] ] # shift x data by xy_delay_samples relative to the y data (example value: 3) # [ E_in_x [0:N], [ E_in_x [ 0:N ], [ E_in_x [3:N ], # E_in_y [0:N], -> E_in_y [-3:N-3], -> E_in_y [0:N-3], # E_out_x[0:N], E_out_x[ 0:N ], E_out_x[3:N ], # E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ] if self.xy_delay_samples != 0: data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1) if self.xy_delay_samples > 0: data_raw = data_raw[:, self.xy_delay_samples :] elif self.xy_delay_samples < 0: data_raw = data_raw[:, : self.xy_delay_samples] # shift fiber input data (target) by target_delay_samples relative to the fiber output data (input) # (example value: 5) # [ E_in_x [0:N], [ E_in_x [-5:N-5], [ E_in_x [0:N-5], # E_in_y [0:N], -> E_in_y [-5:N-5], -> E_in_y [0:N-5], # E_out_x[0:N], E_out_x[ 0:N ], E_out_x[5:N ], # E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ] ] if self.target_delay_samples != 0: data_raw = roll_along( data_raw, [self.target_delay_samples, self.target_delay_samples, 0, 0], dim=1, ) if self.target_delay_samples > 0: data_raw = data_raw[:, self.target_delay_samples :] elif self.target_delay_samples < 0: data_raw = data_raw[:, : self.target_delay_samples] data_raw = data_raw.view(2, 2, -1) # data layout # [ [E_in_x, E_in_y], # [E_out_x, E_out_y] ] self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) self.data = self.data.movedim(-2, 0) # -> [no_slices, 2, 2, samples_per_slice] # data layout # [ # [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ], # [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ], # ... # ] -> [no_slices, 2, 2, samples_per_slice] ... def __len__(self): return self.data.shape[0] def __getitem__(self, idx): if isinstance(idx, slice): return [self.__getitem__(i) for i in range(*idx.indices(len(self)))] else: data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze() # reduce by by taking self.output_dim equally spaced samples data = data[:, : data.shape[1] // self.output_dim * self.output_dim] data = data.view(data.shape[0], self.output_dim, -1) data = data[:, :, 0] # target is corresponding to the middle of the data as the output sample is influenced by the data before and after it target = target[:, : target.shape[1] // self.output_dim * self.output_dim] target = target.view(target.shape[0], self.output_dim, -1) target = target[:, 0, target.shape[2] // 2] data = data.transpose(0, 1).flatten().squeeze() target = target.flatten().squeeze() # data layout: # [sample_x0, sample_y0, sample_x1, sample_y1, ...] # target layout: # [sample_x0, sample_y0] return data, target