enhance data loading and processing in FiberRegenerationDataset; add timestamps and support for multiple file paths
@@ -40,7 +40,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
     if symbols is None:
         symbols = int(config["glova"]["nos"]) - skipfirst
 
-    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
+    data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
+    timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
 
     if normalize:
         # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
@@ -53,6 +54,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
 
     config["glova"]["nos"] = str(symbols)
 
+    data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1)
+
     data = torch.tensor(data, device=device, dtype=dtype)
 
     return data, config
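Taken together, the two load_data hunks attach the absolute sample index of every retained sample as an extra column, so each slice can later be traced back to its position in the original dump. A minimal numpy sketch of that mechanism (array shapes and the four field columns are illustrative assumptions, not taken from this commit):

    import numpy as np

    sps = 128                    # samples per symbol, as in config["glova"]["sps"]
    skipfirst, symbols = 0.5, 4  # may be fractional, hence the new int() casts
    start = int(skipfirst * sps)
    stop = int(symbols * sps + skipfirst * sps)

    fields = np.random.randn(stop - start, 4)   # stand-in for the loaded E-field columns
    timestamps = np.arange(start, stop)         # absolute sample indices
    data = np.concatenate([fields, timestamps.reshape(-1, 1)], axis=-1)
    print(data.shape)                           # (512, 5): four field columns plus the timestamp column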
@@ -100,7 +103,7 @@ class FiberRegenerationDataset(Dataset):
 
     def __init__(
         self,
-        file_path: str | Path,
+        file_path: tuple | list | str | Path,
         symbols: int | float,
         *,
         output_dim: int = None,
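With file_path widened to tuple | list | str | Path, several dumps can be handed to one dataset; a hypothetical call (the import path, file paths and keyword values are placeholders, not from the repository):

    # from <module> import FiberRegenerationDataset  (module path not shown in this diff)
    dataset = FiberRegenerationDataset(
        ["run_a/config.ini", "run_b/config.ini"],  # concatenated along the sample axis
        symbols=10,
        output_dim=64,
    )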
@@ -130,12 +133,12 @@ class FiberRegenerationDataset(Dataset):
         """
 
         # check types
-        assert isinstance(file_path, str), "file_path must be a string"
+        assert isinstance(file_path, (str, Path, tuple, list)), "file_path must be a string, Path, tuple, or list"
         assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
         assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
         assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
         assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
-        assert isinstance(drop_first, int), "drop_first must be an integer"
+        # assert isinstance(drop_first, int), "drop_first must be an integer"
 
         # check values
         assert symbols > 0, "symbols must be positive"
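Relaxing the drop_first check reads as intentional alongside the int() casts in the first hunk: a fractional drop_first flows into load_data as skipfirst and is rounded down to whole samples there (my interpretation; the value 0.5 below is only an example):

    sps = 128
    drop_first = 0.5                    # previously rejected by the assert
    assert int(drop_first * sps) == 64  # half a symbol's worth of samples is skipped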
@@ -150,21 +153,39 @@ class FiberRegenerationDataset(Dataset):
                 dtype=np.complex128,
             )
             data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
+            timestamps = torch.arange(12800)
+
+            data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
+
             self.config = {
                 "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                 "glova": {"sps": 128},
             }
         else:
-            data_raw, self.config = load_data(
-                file_path,
-                skipfirst=drop_first,
-                symbols=kwargs.pop("num_symbols", None),
-                real=real,
-                normalize=True,
-                device=device,
-                dtype=dtype,
-            )
+            data_raw = None
+            self.config = None
+            files = []
+            for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
+                data, config = load_data(
+                    file_path,
+                    skipfirst=drop_first,
+                    symbols=kwargs.get("num_symbols", None),
+                    real=real,
+                    normalize=True,
+                    device=device,
+                    dtype=dtype,
+                )
+                if data_raw is None:
+                    data_raw = data
+                else:
+                    data_raw = torch.cat([data_raw, data], dim=0)
+                if self.config is None:
+                    self.config = config
+                else:
+                    assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
+                files.append(config["data"]["file"].strip('"'))
+            self.config["data"]["file"] = str(files)
 
         self.device = data_raw.device
 
         self.samples_per_symbol = int(self.config["glova"]["sps"])
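The rewritten else-branch boils down to the loop below, shown as a standalone sketch with a hypothetical load_one(path) callback standing in for load_data (assumed to return (tensor[samples, columns], config) the way load_data does); this is the shape of the change, not the committed code:

    import torch

    def concat_runs(paths, load_one):
        # Concatenate several dumps along the sample axis and merge their configs.
        data_raw, config, files = None, None, []
        for path in (paths if isinstance(paths, (tuple, list)) else [paths]):
            data, cfg = load_one(path)
            data_raw = data if data_raw is None else torch.cat([data_raw, data], dim=0)
            if config is None:
                config = cfg
            else:
                # mixing sampling rates would silently corrupt the slicing logic later on
                assert config["glova"]["sps"] == cfg["glova"]["sps"], "samples per symbol must be the same"
            files.append(cfg["data"]["file"].strip('"'))
        config["data"]["file"] = str(files)
        return data_raw, config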
@@ -190,27 +211,29 @@ class FiberRegenerationDataset(Dataset):
         # data_raw = torch.tensor(data_raw, dtype=dtype)
 
         # data layout
-        # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
-        #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
+        # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0, timestamp0],
+        #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1, timestamp1],
         #   ...
-        #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
+        #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN, timestampN] ]
 
         data_raw = data_raw.transpose(0, 1)
 
         # data layout
         # [ E_in_x[0:N],
         #   E_in_y[0:N],
         #   E_out_x[0:N],
-        #   E_out_y[0:N] ]
+        #   E_out_y[0:N],
+        #   timestamps[0:N] ]
 
         # shift x data by xy_delay_samples relative to the y data (example value: 3)
         # [ E_in_x [0:N],       [ E_in_x [ 0:N ],       [ E_in_x [3:N  ],
         #   E_in_y [0:N],   ->    E_in_y [-3:N-3],  ->    E_in_y [0:N-3],
         #   E_out_x[0:N],         E_out_x[ 0:N ],         E_out_x[3:N  ],
-        #   E_out_y[0:N] ]        E_out_y[-3:N-3] ]       E_out_y[0:N-3] ]
+        #   E_out_y[0:N] ]        E_out_y[-3:N-3] ]       E_out_y[0:N-3],
+        #   timestamps[0:N] ]     timestamps[ 0:N ] ]     timestamps[3:N ] ]
 
         if self.xy_delay_samples != 0:
-            data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
+            data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples, 0], dim=1)
             if self.xy_delay_samples > 0:
                 data_raw = data_raw[:, self.xy_delay_samples :]
             elif self.xy_delay_samples < 0:
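roll_along is a helper defined elsewhere in this module; the call sites above (and the target-delay one in the next hunk) only gain a trailing 0 so that the new timestamp row is never shifted. A minimal stand-in with the per-row behaviour the call sites rely on (an assumption about roll_along, not its actual implementation):

    import torch

    def roll_along_sketch(x: torch.Tensor, shifts, dim: int = 1) -> torch.Tensor:
        # Roll each row x[i] by its own shift along the sample axis (assumes a 2-D tensor, dim=1).
        return torch.stack([torch.roll(x[i], shifts=int(s), dims=0) for i, s in enumerate(shifts)], dim=0)

    rows = torch.arange(10).repeat(5, 1)                       # 5 rows x 10 samples
    shifted = roll_along_sketch(rows, [0, 3, 0, 3, 0], dim=1)  # only the two y rows move
    # rows 0, 2 and 4 (E_in_x, E_out_x, timestamps) stay put, matching the xy_delay handling above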
@@ -221,12 +244,13 @@ class FiberRegenerationDataset(Dataset):
         # [ E_in_x [0:N],       [ E_in_x [-5:N-5],      [ E_in_x [0:N-5],
         #   E_in_y [0:N],   ->    E_in_y [-5:N-5],  ->    E_in_y [0:N-5],
         #   E_out_x[0:N],         E_out_x[ 0:N ],         E_out_x[5:N  ],
-        #   E_out_y[0:N] ]        E_out_y[ 0:N ] ]        E_out_y[5:N  ] ]
+        #   E_out_y[0:N] ]        E_out_y[ 0:N ] ]        E_out_y[5:N  ],
+        #   timestamps[0:N] ]     timestamps[ 0:N ] ]     timestamps[5:N ]
 
         if self.target_delay_samples != 0:
             data_raw = roll_along(
                 data_raw,
-                [self.target_delay_samples, self.target_delay_samples, 0, 0],
+                [self.target_delay_samples, self.target_delay_samples, 0, 0, 0],
                 dim=1,
             )
             if self.target_delay_samples > 0:
@@ -234,21 +258,25 @@ class FiberRegenerationDataset(Dataset):
             elif self.target_delay_samples < 0:
                 data_raw = data_raw[:, : self.target_delay_samples]
 
+        timestamps = data_raw[-1, :]
+        data_raw = data_raw[:-1, :]
         data_raw = data_raw.view(2, 2, -1)
+        timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
+        data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
         # data layout
-        # [ [E_in_x, E_in_y],
-        #   [E_out_x, E_out_y] ]
+        # [ [E_in_x, E_in_y, timestamps],
+        #   [E_out_x, E_out_y, timestamps] ]
 
         self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
         self.data = self.data.movedim(-2, 0)
-        # -> [no_slices, 2, 2, samples_per_slice]
+        # -> [no_slices, 2, 3, samples_per_slice]
 
         # data layout
         # [
-        #   [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
-        #   [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
+        #   [ [E_in_x[0:N+0], E_in_y[0:N+0], timestamps[0:N+0]], [ E_out_x[0:N+0], E_out_y[0:N+0], timestamps[0:N+0] ] ],
+        #   [ [E_in_x[1:N+1], E_in_y[1:N+1], timestamps[1:N+1]], [ E_out_x[1:N+1], E_out_y[1:N+1], timestamps[1:N+1] ] ],
         #   ...
-        # ] -> [no_slices, 2, 2, samples_per_slice]
+        # ] -> [no_slices, 2, 3, samples_per_slice]
 
         ...
 
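Shape-wise, the timestamp split, duplication and unfold above do the following (a toy-sized sketch with 12 samples and a slice length of 4; both numbers are arbitrary):

    import torch

    n, slice_len = 12, 4
    data_raw = torch.randn(5, n)                          # 4 field rows + 1 timestamp row
    data_raw[-1] = torch.arange(n, dtype=data_raw.dtype)  # the appended timestamps

    timestamps = data_raw[-1, :]
    fields = data_raw[:-1, :].view(2, 2, -1)              # [in/out, x/y, samples]
    ts = torch.cat([timestamps.unsqueeze(0), timestamps.unsqueeze(0)], dim=0).unsqueeze(1)
    fields = torch.cat([fields, ts], dim=1)               # [2, 3, samples]: x, y, timestamps

    slices = fields.unfold(dimension=-1, size=slice_len, step=1).movedim(-2, 0)
    print(slices.shape)                                   # torch.Size([9, 2, 3, 4]) -> [no_slices, 2, 3, samples_per_slice]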
@@ -259,24 +287,24 @@ class FiberRegenerationDataset(Dataset):
         if isinstance(idx, slice):
             return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
         else:
-            data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
+            data_slice = self.data[idx].squeeze()
 
-            # reduce by by taking self.output_dim equally spaced samples
-            data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
-            data = data.view(data.shape[0], self.output_dim, -1)
-            data = data[:, :, 0]
+            data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
+            data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
+            target = data_slice[0, :, self.output_dim//2, 0]
+            data = data_slice[1, :, :, 0]
 
-            # target is corresponding to the middle of the data as the output sample is influenced by the data before and after it
-            target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
-            target = target.view(target.shape[0], self.output_dim, -1)
-            target = target[:, 0, target.shape[2] // 2]
+            # data_timestamps = data[-1,:].real
+            data = data[:-1, :]
+            target_timestamp = target[-1].real
+            target = target[:-1]
 
             data = data.transpose(0, 1).flatten().squeeze()
+            # data_timestamps = data_timestamps.flatten().squeeze()
             target = target.flatten().squeeze()
+            target_timestamp = target_timestamp.flatten().squeeze()
 
-            # data layout:
-            # [sample_x0, sample_y0, sample_x1, sample_y1, ...]
-            # target layout:
-            # [sample_x0, sample_y0]
-
-            return data, target
+            return data, target, target_timestamp
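From the caller's side, __getitem__ now returns a third value, and the timestamp rows are stripped before the fields are flattened. A sketch of the new access pattern (shapes and the random slice are illustrative; data_slice plays the role of self.data[idx].squeeze()):

    import torch

    output_dim, samples_per_slice = 4, 16
    data_slice = torch.randn(2, 3, samples_per_slice)   # [in/out, (x, y, timestamps), samples]

    # keep output_dim equally spaced samples, as in __getitem__
    data_slice = data_slice[:, :, : data_slice.shape[2] // output_dim * output_dim]
    data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], output_dim, -1)

    target = data_slice[0, :, output_dim // 2, 0]       # fibre input at the centre position: [x, y, timestamp]
    data = data_slice[1, :, :, 0]                       # fibre output at all output_dim positions

    target_timestamp = target[-1].real                  # timestamp of the predicted sample
    data, target = data[:-1, :], target[:-1]            # drop the timestamp row/entry from the fields

    data = data.transpose(0, 1).flatten()               # [x0, y0, x1, y1, ...]
    print(data.shape, target.shape, target_timestamp)   # torch.Size([8]) torch.Size([2]) tensor(...)

Callers that previously unpacked two values will need to accommodate the extra target_timestamp return value.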