refactor complex loss functions for improved readability; update settings and dataset classes for consistency

This commit is contained in:
Joseph Hopfmüller
2024-11-24 01:55:32 +01:00
parent 9a16a5637d
commit 7343ccb3a5
4 changed files with 392 additions and 361 deletions

View File

@@ -1,6 +1,7 @@
from pathlib import Path
import torch
from torch.utils.data import Dataset
# from torch.utils.data import Sampler
import numpy as np
import configparser
@@ -22,6 +23,7 @@ import configparser
# def __len__(self):
# return len(self.indices)
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
@@ -43,18 +45,19 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
data = np.sqrt(np.array([a, b, c, d]).T)
if real:
data = np.abs(data)
config["glova"]["nos"] = str(symbols)
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
def roll_along(arr, shifts, dim):
# https://stackoverflow.com/a/76920720
# (c) Mateen Ulhaq, 2023
@@ -67,6 +70,7 @@ def roll_along(arr, shifts, dim):
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
return torch.gather(arr, dim, indices)
class FiberRegenerationDataset(Dataset):
"""
Dataset for fiber regeneration training.
@@ -105,7 +109,7 @@ class FiberRegenerationDataset(Dataset):
drop_first: float | int = 0,
dtype: torch.dtype = None,
real: bool = False,
device = None,
device=None,
**kwargs,
):
"""
@@ -127,18 +131,10 @@ class FiberRegenerationDataset(Dataset):
# check types
assert isinstance(file_path, str), "file_path must be a string"
assert isinstance(symbols, (float, int)), (
"symbols must be a float or an integer"
)
assert output_dim is None or isinstance(output_dim, int), (
"output_len must be an integer"
)
assert isinstance(target_delay, (float, int)), (
"target_delay must be a float or an integer"
)
assert isinstance(xy_delay, (float, int)), (
"xy_delay must be a float or an integer"
)
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
assert isinstance(drop_first, int), "drop_first must be an integer"
# check values
@@ -159,10 +155,18 @@ class FiberRegenerationDataset(Dataset):
"glova": {"sps": 128},
}
else:
data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype)
data_raw, self.config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.pop("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"])
self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -180,9 +184,7 @@ class FiberRegenerationDataset(Dataset):
else int(self.target_delay * self.samples_per_symbol)
)
self.xy_delay_samples = (
ovrd_xy_delay_samples
if ovrd_xy_delay_samples is not None
else int(self.xy_delay * self.samples_per_symbol)
ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
)
# data_raw = torch.tensor(data_raw, dtype=dtype)
@@ -190,15 +192,15 @@ class FiberRegenerationDataset(Dataset):
# data layout
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
# ...
# ...
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
data_raw = data_raw.transpose(0, 1)
# data layout
# [ E_in_x[0:N],
# E_in_y[0:N],
# E_out_x[0:N],
# [ E_in_x[0:N],
# E_in_y[0:N],
# E_out_x[0:N],
# E_out_y[0:N] ]
# shift y data by xy_delay_samples relative to the x data (example value: 3)
@@ -208,9 +210,7 @@ class FiberRegenerationDataset(Dataset):
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
if self.xy_delay_samples != 0:
data_raw = roll_along(
data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
)
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
if self.xy_delay_samples > 0:
data_raw = data_raw[:, self.xy_delay_samples :]
elif self.xy_delay_samples < 0: