refactor complex loss functions for improved readability; update settings and dataset classes for consistency
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# from torch.utils.data import Sampler
|
||||
import numpy as np
|
||||
import configparser
|
||||
@@ -22,6 +23,7 @@ import configparser
|
||||
# def __len__(self):
|
||||
# return len(self.indices)
|
||||
|
||||
|
||||
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
|
||||
filepath = Path(config_path)
|
||||
filepath = filepath.parent.glob(filepath.name)
|
||||
@@ -43,18 +45,19 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
|
||||
if normalize:
|
||||
# normalize each column so that its squared amplitude (which is proportional to the power) has a maximum of 1, then take the square root again
|
||||
a, b, c, d = np.square(data.T)
|
||||
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
|
||||
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||
data = np.sqrt(np.array([a, b, c, d]).T)
|
||||
|
||||
if real:
|
||||
data = np.abs(data)
|
||||
|
||||
|
||||
config["glova"]["nos"] = str(symbols)
|
||||
|
||||
data = torch.tensor(data, device=device, dtype=dtype)
|
||||
|
||||
return data, config
|
||||
|
||||
|
||||
def roll_along(arr, shifts, dim):
|
||||
# https://stackoverflow.com/a/76920720
|
||||
# (c) Mateen Ulhaq, 2023
|
||||
@@ -67,6 +70,7 @@ def roll_along(arr, shifts, dim):
|
||||
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
|
||||
return torch.gather(arr, dim, indices)
|
||||
|
||||
|
||||
class FiberRegenerationDataset(Dataset):
|
||||
"""
|
||||
Dataset for fiber regeneration training.
|
||||
@@ -105,7 +109,7 @@ class FiberRegenerationDataset(Dataset):
|
||||
drop_first: float | int = 0,
|
||||
dtype: torch.dtype = None,
|
||||
real: bool = False,
|
||||
device = None,
|
||||
device=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -127,18 +131,10 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
# check types
|
||||
assert isinstance(file_path, str), "file_path must be a string"
|
||||
assert isinstance(symbols, (float, int)), (
|
||||
"symbols must be a float or an integer"
|
||||
)
|
||||
assert output_dim is None or isinstance(output_dim, int), (
|
||||
"output_len must be an integer"
|
||||
)
|
||||
assert isinstance(target_delay, (float, int)), (
|
||||
"target_delay must be a float or an integer"
|
||||
)
|
||||
assert isinstance(xy_delay, (float, int)), (
|
||||
"xy_delay must be a float or an integer"
|
||||
)
|
||||
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
|
||||
assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
|
||||
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
|
||||
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
|
||||
assert isinstance(drop_first, int), "drop_first must be an integer"
|
||||
|
||||
# check values
|
||||
@@ -159,10 +155,18 @@ class FiberRegenerationDataset(Dataset):
|
||||
"glova": {"sps": 128},
|
||||
}
|
||||
else:
|
||||
data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype)
|
||||
data_raw, self.config = load_data(
|
||||
file_path,
|
||||
skipfirst=drop_first,
|
||||
symbols=kwargs.pop("num_symbols", None),
|
||||
real=real,
|
||||
normalize=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
self.device = data_raw.device
|
||||
|
||||
|
||||
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||
self.samples_per_slice = int(symbols * self.samples_per_symbol)
|
||||
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
|
||||
@@ -180,9 +184,7 @@ class FiberRegenerationDataset(Dataset):
|
||||
else int(self.target_delay * self.samples_per_symbol)
|
||||
)
|
||||
self.xy_delay_samples = (
|
||||
ovrd_xy_delay_samples
|
||||
if ovrd_xy_delay_samples is not None
|
||||
else int(self.xy_delay * self.samples_per_symbol)
|
||||
ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
|
||||
)
|
||||
|
||||
# data_raw = torch.tensor(data_raw, dtype=dtype)
|
||||
@@ -190,15 +192,15 @@ class FiberRegenerationDataset(Dataset):
|
||||
# data layout
|
||||
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
|
||||
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
|
||||
# ...
|
||||
# ...
|
||||
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
|
||||
|
||||
data_raw = data_raw.transpose(0, 1)
|
||||
|
||||
# data layout
|
||||
# [ E_in_x[0:N],
|
||||
# E_in_y[0:N],
|
||||
# E_out_x[0:N],
|
||||
# [ E_in_x[0:N],
|
||||
# E_in_y[0:N],
|
||||
# E_out_x[0:N],
|
||||
# E_out_y[0:N] ]
|
||||
|
||||
# shift y data by xy_delay_samples relative to the x data (example value: 3)
|
||||
@@ -208,9 +210,7 @@ class FiberRegenerationDataset(Dataset):
|
||||
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
|
||||
|
||||
if self.xy_delay_samples != 0:
|
||||
data_raw = roll_along(
|
||||
data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
|
||||
)
|
||||
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
|
||||
if self.xy_delay_samples > 0:
|
||||
data_raw = data_raw[:, self.xy_delay_samples :]
|
||||
elif self.xy_delay_samples < 0:
|
||||
|
||||
Reference in New Issue
Block a user