move hypertraining class into separate file;

move settings dataclasses into separate file;
add SemiUnitaryLayer;
clean up model response plotting code;
continue hyperparameter search
This commit is contained in:
Joseph Hopfmüller
2024-11-20 22:49:31 +01:00
parent cdca5de473
commit 674033ac2e
11 changed files with 1064 additions and 553 deletions

View File

@@ -41,9 +41,10 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
if normalize:
a, b, c, d = data.T
# normalize the squared values to a maximum of 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
data = np.array([a, b, c, d]).T
data = np.sqrt(np.array([a, b, c, d]).T)
if real:
data = np.abs(data)
@@ -98,7 +99,7 @@ class FiberRegenerationDataset(Dataset):
file_path: str | Path,
symbols: int | float,
*,
data_size: int = None,
output_dim: int = None,
target_delay: float | int = 0,
xy_delay: float | int = 0,
drop_first: float | int = 0,
@@ -129,7 +130,7 @@ class FiberRegenerationDataset(Dataset):
assert isinstance(symbols, (float, int)), (
"symbols must be a float or an integer"
)
assert data_size is None or isinstance(data_size, int), (
assert output_dim is None or isinstance(output_dim, int), (
"output_len must be an integer"
)
assert isinstance(target_delay, (float, int)), (
@@ -142,7 +143,7 @@ class FiberRegenerationDataset(Dataset):
# check values
assert symbols > 0, "symbols must be positive"
assert data_size is None or data_size > 0, "output_len must be positive or None"
assert output_dim is None or output_dim > 0, "output_len must be positive or None"
assert drop_first >= 0, "drop_first must be non-negative"
faux = kwargs.pop("faux", False)
@@ -158,7 +159,7 @@ class FiberRegenerationDataset(Dataset):
"glova": {"sps": 128},
}
else:
data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype)
self.device = data_raw.device
@@ -166,7 +167,7 @@ class FiberRegenerationDataset(Dataset):
self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
self.data_size = data_size or self.samples_per_slice
self.output_dim = output_dim or self.samples_per_slice
self.target_delay = target_delay or 0
self.xy_delay = xy_delay or 0
@@ -261,13 +262,13 @@ class FiberRegenerationDataset(Dataset):
data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
# reduce by taking self.output_dim equally spaced samples
data = data[:, : data.shape[1] // self.data_size * self.data_size]
data = data.view(data.shape[0], self.data_size, -1)
data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
data = data.view(data.shape[0], self.output_dim, -1)
data = data[:, :, 0]
# target corresponds to the middle of the data, as the output sample is influenced by the data before and after it
target = target[:, : target.shape[1] // self.data_size * self.data_size]
target = target.view(target.shape[0], self.data_size, -1)
target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
target = target.view(target.shape[0], self.output_dim, -1)
target = target[:, 0, target.shape[2] // 2]
data = data.transpose(0, 1).flatten().squeeze()