enhance data loading and processing in FiberRegenerationDataset; add timestamps and support for multiple file paths

This commit is contained in:
Joseph Hopfmüller
2024-12-02 18:49:14 +01:00
parent e20aa9bfb1
commit a5f2f49360

View File

@@ -40,7 +40,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
@@ -53,6 +54,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
config["glova"]["nos"] = str(symbols)
data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1)
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
@@ -100,7 +103,7 @@ class FiberRegenerationDataset(Dataset):
def __init__(
self,
file_path: str | Path,
file_path: tuple | list | str | Path,
symbols: int | float,
*,
output_dim: int = None,
@@ -130,12 +133,12 @@ class FiberRegenerationDataset(Dataset):
"""
# check types
assert isinstance(file_path, str), "file_path must be a string"
assert isinstance(file_path, (str, Path, tuple, list)), "file_path must be a string, Path, tuple, or list"
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
assert isinstance(drop_first, int), "drop_first must be an integer"
# assert isinstance(drop_first, int), "drop_first must be an integer"
# check values
assert symbols > 0, "symbols must be positive"
@@ -150,20 +153,38 @@ class FiberRegenerationDataset(Dataset):
dtype=np.complex128,
)
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
timestamps = torch.arange(12800)
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw, self.config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.pop("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
data_raw = None
self.config = None
files = []
for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
data, config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
if data_raw is None:
data_raw = data
else:
data_raw = torch.cat([data_raw, data], dim=0)
if self.config is None:
self.config = config
else:
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
self.device = data_raw.device
@@ -190,27 +211,29 @@ class FiberRegenerationDataset(Dataset):
# data_raw = torch.tensor(data_raw, dtype=dtype)
# data layout
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0, timestamp0],
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1, timestamp1],
# ...
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN, timestampN] ]
data_raw = data_raw.transpose(0, 1)
# data layout
# [ E_in_x[0:N],
# E_in_y[0:N],
# E_out_x[0:N],
# E_out_y[0:N] ]
# [ E_in_x[0:N],
# E_in_y[0:N],
# E_out_x[0:N],
# E_out_y[0:N],
# timestamps[0:N] ]
# shift x data by xy_delay_samples relative to the y data (example value: 3)
# [ E_in_x [0:N], [ E_in_x [ 0:N ], [ E_in_x [3:N ],
# E_in_y [0:N], -> E_in_y [-3:N-3], -> E_in_y [0:N-3],
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[3:N ],
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3],
# timestamps[0:N] ] timestamps[ 0:N ] ] timestamps[3:N ] ]
if self.xy_delay_samples != 0:
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples, 0], dim=1)
if self.xy_delay_samples > 0:
data_raw = data_raw[:, self.xy_delay_samples :]
elif self.xy_delay_samples < 0:
@@ -221,12 +244,13 @@ class FiberRegenerationDataset(Dataset):
# [ E_in_x [0:N], [ E_in_x [-5:N-5], [ E_in_x [0:N-5],
# E_in_y [0:N], -> E_in_y [-5:N-5], -> E_in_y [0:N-5],
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[5:N ],
# E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ] ]
# E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ],
# timestamps[0:N] ] timestamps[ 0:N ] ] timestamps[5:N ] ]
if self.target_delay_samples != 0:
data_raw = roll_along(
data_raw,
[self.target_delay_samples, self.target_delay_samples, 0, 0],
[self.target_delay_samples, self.target_delay_samples, 0, 0, 0],
dim=1,
)
if self.target_delay_samples > 0:
@@ -234,21 +258,25 @@ class FiberRegenerationDataset(Dataset):
elif self.target_delay_samples < 0:
data_raw = data_raw[:, : self.target_delay_samples]
timestamps = data_raw[-1, :]
data_raw = data_raw[:-1, :]
data_raw = data_raw.view(2, 2, -1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout
# [ [E_in_x, E_in_y],
# [E_out_x, E_out_y] ]
# [ [E_in_x, E_in_y, timestamps],
# [E_out_x, E_out_y, timestamps] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0)
# -> [no_slices, 2, 2, samples_per_slice]
# -> [no_slices, 2, 3, samples_per_slice]
# data layout
# [
# [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
# [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
# [ [E_in_x[0:N+0], E_in_y[0:N+0], timestamps[0:N+0]], [ E_out_x[0:N+0], E_out_y[0:N+0], timestamps[0:N+0] ] ],
# [ [E_in_x[1:N+1], E_in_y[1:N+1], timestamps[1:N+1]], [ E_out_x[1:N+1], E_out_y[1:N+1], timestamps[1:N+1] ] ],
# ...
# ] -> [no_slices, 2, 2, samples_per_slice]
# ] -> [no_slices, 2, 3, samples_per_slice]
...
@@ -259,24 +287,24 @@ class FiberRegenerationDataset(Dataset):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
else:
data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
data_slice = self.data[idx].squeeze()
# reduce by taking self.output_dim equally spaced samples
data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
data = data.view(data.shape[0], self.output_dim, -1)
data = data[:, :, 0]
data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
target = data_slice[0, :, self.output_dim//2, 0]
data = data_slice[1, :, :, 0]
# data_timestamps = data[-1,:].real
data = data[:-1, :]
target_timestamp = target[-1].real
target = target[:-1]
# target corresponds to the middle of the data, as the output sample is influenced by the data before and after it
target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
target = target.view(target.shape[0], self.output_dim, -1)
target = target[:, 0, target.shape[2] // 2]
data = data.transpose(0, 1).flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze()
target = target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze()
# data layout:
# [sample_x0, sample_y0, sample_x1, sample_y1, ...]
# target layout:
# [sample_x0, sample_y0]
return data, target
return data, target, target_timestamp