From a5f2f493604af917330af41f8c543deecf82d151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joseph=20Hopfm=C3=BCller?= Date: Mon, 2 Dec 2024 18:49:14 +0100 Subject: [PATCH] enhance data loading and processing in FiberRegenerationDataset; add timestamps and support for multiple file paths --- src/single-core-regen/util/datasets.py | 120 +++++++++++++++---------- 1 file changed, 74 insertions(+), 46 deletions(-) diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py index 0aac64e..98e7781 100644 --- a/src/single-core-regen/util/datasets.py +++ b/src/single-core-regen/util/datasets.py @@ -40,7 +40,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals if symbols is None: symbols = int(config["glova"]["nos"]) - skipfirst - data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps] + data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)] + timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps)) if normalize: # square gets normalized to 1, as the power is (proportional to) the square of the amplitude @@ -53,6 +54,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals config["glova"]["nos"] = str(symbols) + data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1) + data = torch.tensor(data, device=device, dtype=dtype) return data, config @@ -100,7 +103,7 @@ class FiberRegenerationDataset(Dataset): def __init__( self, - file_path: str | Path, + file_path: tuple | list | str | Path, symbols: int | float, *, output_dim: int = None, @@ -130,12 +133,12 @@ class FiberRegenerationDataset(Dataset): """ # check types - assert isinstance(file_path, str), "file_path must be a string" + assert isinstance(file_path, (str, Path, tuple, list)), "file_path must be a string, Path, tuple, or list" assert isinstance(symbols, (float, int)), "symbols must be a float or an integer" assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer" assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer" assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer" - assert isinstance(drop_first, int), "drop_first must be an integer" + # assert isinstance(drop_first, int), "drop_first must be an integer" # check values assert symbols > 0, "symbols must be positive" @@ -150,21 +153,39 @@ class FiberRegenerationDataset(Dataset): dtype=np.complex128, ) data_raw = torch.tensor(data_raw, device=device, dtype=dtype) + timestamps = torch.arange(12800) + + data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1) + self.config = { "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"}, "glova": {"sps": 128}, } else: - data_raw, self.config = load_data( - file_path, - skipfirst=drop_first, - symbols=kwargs.pop("num_symbols", None), - real=real, - normalize=True, - device=device, - dtype=dtype, - ) - + data_raw = None + self.config = None + files = [] + for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]): + data, config = load_data( + file_path, + skipfirst=drop_first, + symbols=kwargs.get("num_symbols", None), + real=real, + normalize=True, + device=device, + dtype=dtype, + ) + if data_raw is None: + data_raw = data + else: + data_raw = torch.cat([data_raw, data], dim=0) + if self.config is None: + self.config = config + else: + assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same" + files.append(config["data"]["file"].strip('"')) + self.config["data"]["file"] = str(files) + self.device = data_raw.device self.samples_per_symbol = int(self.config["glova"]["sps"]) @@ -190,27 +211,29 @@ class FiberRegenerationDataset(Dataset): # data_raw = torch.tensor(data_raw, dtype=dtype) # data layout - # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0], - # [E_in_x1, E_in_y1, E_out_x1, E_out_y1], + # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0, timestamp0], + # [E_in_x1, E_in_y1, E_out_x1, E_out_y1, timestamp1], # ... - # [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ] + # [E_in_xN, E_in_yN, E_out_xN, E_out_yN, timestampN] ] data_raw = data_raw.transpose(0, 1) # data layout - # [ E_in_x[0:N], - # E_in_y[0:N], - # E_out_x[0:N], - # E_out_y[0:N] ] + # [ E_in_x[0:N], + # E_in_y[0:N], + # E_out_x[0:N], + # E_out_y[0:N], + # timestamps[0:N] ] # shift x data by xy_delay_samples relative to the y data (example value: 3) # [ E_in_x [0:N], [ E_in_x [ 0:N ], [ E_in_x [3:N ], # E_in_y [0:N], -> E_in_y [-3:N-3], -> E_in_y [0:N-3], # E_out_x[0:N], E_out_x[ 0:N ], E_out_x[3:N ], - # E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ] + # E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3], + # timestamps[0:N] ] timestamps[ 0:N ] ] timestamps[3:N ] ] if self.xy_delay_samples != 0: - data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1) + data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples, 0], dim=1) if self.xy_delay_samples > 0: data_raw = data_raw[:, self.xy_delay_samples :] elif self.xy_delay_samples < 0: @@ -221,12 +244,13 @@ class FiberRegenerationDataset(Dataset): # [ E_in_x [0:N], [ E_in_x [-5:N-5], [ E_in_x [0:N-5], # E_in_y [0:N], -> E_in_y [-5:N-5], -> E_in_y [0:N-5], # E_out_x[0:N], E_out_x[ 0:N ], E_out_x[5:N ], - # E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ] ] + # E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ], + # timestamps[0:N] ] timestamps[ 0:N ] ] timestamps[5:N ] if self.target_delay_samples != 0: data_raw = roll_along( data_raw, - [self.target_delay_samples, self.target_delay_samples, 0, 0], + [self.target_delay_samples, self.target_delay_samples, 0, 0, 0], dim=1, ) if self.target_delay_samples > 0: @@ -234,21 +258,25 @@ class FiberRegenerationDataset(Dataset): elif self.target_delay_samples < 0: data_raw = data_raw[:, : self.target_delay_samples] + timestamps = data_raw[-1, :] + data_raw = data_raw[:-1, :] data_raw = data_raw.view(2, 2, -1) + timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1) + data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) # data layout - # [ [E_in_x, E_in_y], - # [E_out_x, E_out_y] ] + # [ [E_in_x, E_in_y, timestamps], + # [E_out_x, E_out_y, timestamps] ] self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) self.data = self.data.movedim(-2, 0) - # -> [no_slices, 2, 2, samples_per_slice] + # -> [no_slices, 2, 3, samples_per_slice] # data layout # [ - # [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ], - # [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ], + # [ [E_in_x[0:N+0], E_in_y[0:N+0], timestamps[0:N+0]], [ E_out_x[0:N+0], E_out_y[0:N+0], timestamps[0:N+0] ] ], + # [ [E_in_x[1:N+1], E_in_y[1:N+1], timestamps[1:N+1]], [ E_out_x[1:N+1], E_out_y[1:N+1], timestamps[1:N+1] ] ], # ... - # ] -> [no_slices, 2, 2, samples_per_slice] + # ] -> [no_slices, 2, 3, samples_per_slice] ... @@ -259,24 +287,24 @@ class FiberRegenerationDataset(Dataset): if isinstance(idx, slice): return [self.__getitem__(i) for i in range(*idx.indices(len(self)))] else: - data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze() + data_slice = self.data[idx].squeeze() - # reduce by by taking self.output_dim equally spaced samples - data = data[:, : data.shape[1] // self.output_dim * self.output_dim] - data = data.view(data.shape[0], self.output_dim, -1) - data = data[:, :, 0] + data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim] + + data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1) + + target = data_slice[0, :, self.output_dim//2, 0] + data = data_slice[1, :, :, 0] + + # data_timestamps = data[-1,:].real + data = data[:-1, :] + target_timestamp = target[-1].real + target = target[:-1] - # target is corresponding to the middle of the data as the output sample is influenced by the data before and after it - target = target[:, : target.shape[1] // self.output_dim * self.output_dim] - target = target.view(target.shape[0], self.output_dim, -1) - target = target[:, 0, target.shape[2] // 2] data = data.transpose(0, 1).flatten().squeeze() + # data_timestamps = data_timestamps.flatten().squeeze() target = target.flatten().squeeze() + target_timestamp = target_timestamp.flatten().squeeze() - # data layout: - # [sample_x0, sample_y0, sample_x1, sample_y1, ...] - # target layout: - # [sample_x0, sample_y0] - - return data, target + return data, target, target_timestamp