Compare commits

6 commits: 010889af13 ... a8a1c49c00

| Author | SHA1 | Date |
|---|---|---|
|  | a8a1c49c00 |  |
|  | 297e9e8d7f |  |
|  | aa2e7a4cb4 |  |
|  | 1dcefecf59 |  |
|  | a5f2f49360 |  |
|  | e20aa9bfb1 |  |
.gitmodules (vendored, 1 change)
@@ -1,3 +1,4 @@
 [submodule "pypho"]
 	path = pypho
 	url = git@gitlab.lrz.de:000000003B9B3E61/pypho.git
+	branch = main
@@ -33,7 +33,7 @@ flags = "FFTW_PATIENT"
 nthreads = 32
 
 [fiber]
-length = 80000
+length = 10000
 gamma = 1.14
 alpha = 0.2
 D = 17
@@ -201,7 +201,7 @@ def initialize_fiber_and_data(config, input_data_override=None):
         "jitter_seed", (int(time.time() * 1000)) % 2**32
     )
     symbolsrc = pypho.symbols(
-        py_glova, py_glova.nos, pattern="ones", seed=config["signal"]["seed"]
+        py_glova, py_glova.nos, pattern="random", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
    )
    laser = pypho.lasmod(
        py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
@@ -214,8 +214,8 @@ def initialize_fiber_and_data(config, input_data_override=None):
         seed=config["signal"]["jitter_seed"],
     )
 
-    symbols_x = symbolsrc(pattern="random", p1=config["signal"]["mod_order"])
-    symbols_y = symbolsrc(pattern="random", p1=config["signal"]["mod_order"])
+    symbols_x = symbolsrc()
+    symbols_y = symbolsrc()
     symbols_x[:3] = 0
     symbols_y[:3] = 0
 
@@ -258,12 +258,12 @@ class HyperTraining:
                 f"model_hidden_dim_{i}",
                 self.model_settings.n_hidden_nodes,
             )
-            layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
+            layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
             last_dim = hidden_dim
             layers.append(getattr(util.complexNN, afunc)())
             n_nodes += last_dim
 
-        layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
+        layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
 
         model = nn.Sequential(*layers)
 
@@ -11,7 +11,7 @@ class GlobalSettings:
 # data settings
 @dataclass
 class DataSettings:
-    config_path: str  # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
+    config_path: tuple | list | str  # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
     dtype: tuple = ("complex64", "float64")
     symbols: tuple | float | int = 8
     output_size: tuple | float | int = 64
@@ -39,7 +39,7 @@ class PytorchSettings:
     summary_dir: str = ".runs"
     write_every: int = 10
     head_symbols: int = 40
-    eye_symbols: int = 400
+    eye_symbols: int = 1000
 
 
 # model settings
@@ -52,13 +52,16 @@ class ModelSettings:
     overrides: dict = field(default_factory=dict)
     dropout_prob: float | None = None
     model_layer_function: str | None = None
+    scale: bool = False
+    model_layer_kwargs: dict | None = None
     model_layer_parametrizations: list = field(default_factory=list)
 
 
 @dataclass
 class OptimizerSettings:
     optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
-    learning_rate: tuple | float = (1e-5, 1e-1)
+    optimizer_kwargs: dict | None = None
+    # learning_rate: tuple | float = (1e-5, 1e-1)
     scheduler: str | None = None
     scheduler_kwargs: dict | None = None
 
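Note on the `OptimizerSettings` change above: the single `learning_rate` field is replaced by a free-form `optimizer_kwargs` dict that is splatted into the optimizer constructor (see the `train()` hunk further down). A minimal stand-alone sketch of that pattern, using only stock `torch.optim`; the concrete kwargs mirror the new config but are otherwise illustrative:

```python
import torch
import torch.optim as optim

params = [torch.nn.Parameter(torch.zeros(2))]
optimizer_kwargs = {"lr": 0.05, "amsgrad": True}  # mirrors the new config values

# the Trainer resolves the optimizer class by name and splats the kwargs
optimizer = getattr(optim, "AdamW")(params, **optimizer_kwargs)
print(optimizer.param_groups[0]["lr"])  # 0.05
```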
@@ -1,8 +1,10 @@
 import copy
 from datetime import datetime
 from pathlib import Path
+import random
 from typing import Literal
 import matplotlib
+from matplotlib.colors import LinearSegmentedColormap
 import torch.nn.utils.parametrize
 
 try:
@@ -50,6 +52,7 @@ class regenerator(nn.Module):
         self,
         *dims,
         layer_function=util.complexNN.ONN,
+        layer_kwargs: dict | None = None,
         layer_parametrizations: list[dict] = None,
         # [
         #     {
@@ -64,6 +67,7 @@ class regenerator(nn.Module):
         activation_function=util.complexNN.Pow,
         dtype=torch.float64,
         dropout_prob=0.01,
+        scale=False,
         **kwargs,
     ):
         super(regenerator, self).__init__()
@@ -74,39 +78,57 @@ class regenerator(nn.Module):
             raise ValueError("dims must be provided")
         self._n_hidden_layers = len(dims) - 2
         self._layers = nn.Sequential()
+        if layer_kwargs is None:
+            layer_kwargs = {}
+        # self.powers = []
+
         for i in range(self._n_hidden_layers + 1):
-            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
+            if scale:
+                self._layers.append(util.complexNN.Scale(dims[i]))
+            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
             if i < self._n_hidden_layers:
                 if dropout_prob is not None:
                     self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
-                self._layers.append(activation_function())
+                self._layers.append(activation_function(bias=True, size=dims[i + 1]))
 
+        self._layers.append(util.complexNN.Scale(dims[-1]))
+
         # add parametrizations
         if layer_parametrizations is not None:
+            for layer in self._layers:
                 for layer_parametrization in layer_parametrizations:
                     tensor_name = layer_parametrization.get("tensor_name", None)
                     parametrization = layer_parametrization.get("parametrization", None)
                     param_kwargs = layer_parametrization.get("kwargs", {})
-                if (
-                    tensor_name is not None
-                    and tensor_name in self._layers[-1]._parameters
-                    and parametrization is not None
-                ):
-                    parametrization(self._layers[-1], tensor_name, **param_kwargs)
+                    if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
+                        parametrization(layer, tensor_name, **param_kwargs)
 
-    def forward(self, input_x):
+    # def __call__(self, input_x, **kwargs):
+    #     return self.forward(input_x, **kwargs)
+
+    def forward(self, input_x, trace_powers=False):
         x = input_x
 
+        if trace_powers:
+            powers = [x.abs().square().sum()]
+
         # check if tracing
         if torch.jit.is_tracing():
             for layer in self._layers:
                 x = layer(x)
+                if trace_powers:
+                    powers.append(x.abs().square().sum())
         else:
             # with torch.nn.utils.parametrize.cached():
             for layer in self._layers:
                 x = layer(x)
+                if trace_powers:
+                    powers.append(x.abs().square().sum())
+        if trace_powers:
+            return x, powers
         return x
 
 
 def traverse_dict_update(target, source):
     for k, v in source.items():
         if isinstance(v, dict):
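The `forward` rewrite above adds an optional per-layer power trace. A stand-alone sketch of the same pattern with stock layers (`nn.Linear` stands in for the repository's `util.complexNN` layers; the dimensions are made up):

```python
import torch
import torch.nn as nn

layers = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

def forward(x, trace_powers=False):
    if trace_powers:
        powers = [x.abs().square().sum()]  # total power at the input
    for layer in layers:
        x = layer(x)
        if trace_powers:
            powers.append(x.abs().square().sum())  # power after every layer
    if trace_powers:
        return x, powers
    return x

y, powers = forward(torch.randn(1, 4), trace_powers=True)
print([round(p.item(), 3) for p in powers])  # one entry per layer boundary
```

The ratio `powers[-1] / powers[0]` is what the new "Energy conservation" plot (further down) visualizes layer by layer.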
@@ -119,6 +141,7 @@ def traverse_dict_update(target, source):
             except TypeError:
                 target.__dict__[k] = v
 
+
 class Trainer:
     def __init__(
         self,
@@ -142,7 +165,7 @@ class Trainer:
             OptimizerSettings,
             PytorchSettings,
             regenerator,
-            torch.nn.utils.parametrizations.orthogonal
+            torch.nn.utils.parametrizations.orthogonal,
         ])
         if self.resume:
             self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
@@ -206,6 +229,11 @@ class Trainer:
         }
 
     def define_model(self, model_kwargs=None):
+        if self.resume:
+            model_kwargs = self.checkpoint_dict["model_kwargs"]
+        else:
+            model_kwargs = model_kwargs
+
         if model_kwargs is None:
             n_hidden_layers = self.model_settings.n_hidden_layers
 
@@ -228,6 +256,7 @@ class Trainer:
                 "activation_function": afunc,
                 "dtype": dtype,
                 "dropout_prob": self.model_settings.dropout_prob,
+                "scale": self.model_settings.scale,
             }
         else:
             self.model_kwargs = model_kwargs
@@ -237,9 +266,12 @@ class Trainer:
         # dims = self.model_kwargs.pop("dims")
         self.model = regenerator(**self.model_kwargs)
 
+        if self.writer is not None:
             self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
 
         self.model = self.model.to(self.pytorch_settings.device)
+        if self.resume:
+            self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False)
 
     def get_sliced_data(self, override=None):
         symbols = self.data_settings.symbols
@@ -253,11 +285,13 @@ class Trainer:
         dtype = getattr(torch, self.data_settings.dtype)
 
         num_symbols = None
+        config_path = self.data_settings.config_path
         if override is not None:
             num_symbols = override.get("num_symbols", None)
+            config_path = override.get("config_path", config_path)
         # get dataset
         dataset = FiberRegenerationDataset(
-            file_path=self.data_settings.config_path,
+            file_path=config_path,
             symbols=symbols,
             output_dim=data_size,
             target_delay=in_out_delay,
@@ -330,10 +364,11 @@ class Trainer:
             task = progress.add_task("-.---e--", total=len(train_loader))
             progress.start()
 
-        running_loss2 = 0.0
+        # running_loss2 = 0.0
         running_loss = 0.0
         self.model.train()
-        for batch_idx, (x, y) in enumerate(train_loader):
+        loader_len = len(train_loader)
+        for batch_idx, (x, y, _) in enumerate(train_loader):
             self.model.zero_grad(set_to_none=True)
             x, y = (
                 x.to(self.pytorch_settings.device),
@@ -344,24 +379,23 @@ class Trainer:
             loss_value = loss.item()
             loss.backward()
             optimizer.step()
-            running_loss2 += loss_value
+            # running_loss2 += loss_value
             running_loss += loss_value
 
             if enable_progress:
-                progress.update(task, advance=1, description=f"{loss_value:.3e}")
+                progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}")
 
             if batch_idx % self.pytorch_settings.write_every == 0:
                 self.writer.add_scalar(
                     "training loss",
-                    running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
-                    epoch * len(train_loader) + batch_idx,
+                    running_loss / (batch_idx + 1),
+                    epoch * loader_len + batch_idx,
                 )
-                running_loss2 = 0.0
 
         if enable_progress:
             progress.stop()
 
-        return running_loss / len(train_loader)
+        return running_loss / (batch_idx + 1)
 
     def eval_model(self, valid_loader, epoch, enable_progress=True):
         if enable_progress:
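The loader unpacking above changes because the dataset now yields `(input, target, timestamp)` triples; the training and eval loops discard the timestamp. A minimal stand-in (the shapes here are arbitrary):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# three tensors per sample, like FiberRegenerationDataset after this change
ds = TensorDataset(torch.randn(8, 4), torch.randn(8, 2), torch.arange(8))
for batch_idx, (x, y, _) in enumerate(DataLoader(ds, batch_size=4)):
    pass  # training step; the timestamp is only used by the plotting paths
```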
@@ -384,7 +418,7 @@ class Trainer:
         self.model.eval()
         running_error = 0
         with torch.no_grad():
-            for batch_idx, (x, y) in enumerate(valid_loader):
+            for batch_idx, (x, y, _) in enumerate(valid_loader):
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
@@ -395,15 +429,17 @@ class Trainer:
                 running_error += error_value
 
                 if enable_progress:
-                    progress.update(task, advance=1, description=f"{error_value:.3e}")
+                    progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}")
 
-        running_error /= len(valid_loader)
+        running_error /= (batch_idx+1)
+
         self.writer.add_scalar(
             "eval loss",
             running_error,
             epoch,
         )
+        if (epoch + 1) % 10 == 0 or epoch < 10:
+            # plotting is slow, so only do it every 10 epochs
             title_append, subtitle = self.build_title(epoch + 1)
             self.writer.add_figure(
                 "fiber response",
@@ -426,45 +462,74 @@ class Trainer:
                 ),
                 epoch + 1,
             )
-            self.writer_histograms(epoch + 1)
+            self.writer.add_figure(
+                "powers",
+                self.plot_model_response(
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    mode="powers",
+                    show=False,
+                ),
+                epoch + 1,
+            )
+
+            self.write_parameters(epoch + 1)
+            self.writer.flush()
 
         if enable_progress:
             progress.stop()
 
         return running_error
 
-    def run_model(self, model, loader):
+    def run_model(self, model, loader, trace_powers=False):
         model.eval()
-        xs = []
-        ys = []
-        y_preds = []
+        fiber_out = []
+        fiber_in = []
+        regen = []
+        timestamps = []
 
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
-            for x, y in loader:
+            for x, y, timestamp in loader:
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
-                y_pred = model(x).cpu()
+                if trace_powers:
+                    y_pred, powers = model(x, trace_powers).cpu()
+                else:
+                    y_pred = model(x, trace_powers).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                 y = y.view(y.shape[0], -1, 2)
                 x = x.view(x.shape[0], -1, 2)
-                xs.append(x[:, 0, :].squeeze())
-                ys.append(y.squeeze())
-                y_preds.append(y_pred.squeeze())
+                # timestamp = timestamp.view(-1, 1)
+                fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
+                fiber_in.append(y.squeeze())
+                regen.append(y_pred.squeeze())
+                timestamps.append(timestamp.squeeze())
 
-        xs = torch.vstack(xs).cpu()
-        ys = torch.vstack(ys).cpu()
-        y_preds = torch.vstack(y_preds).cpu()
-        return ys, xs, y_preds
+        fiber_out = torch.vstack(fiber_out).cpu()
+        fiber_in = torch.vstack(fiber_in).cpu()
+        regen = torch.vstack(regen).cpu()
+        timestamps = torch.concat(timestamps).cpu()
+        if trace_powers:
+            return fiber_in, fiber_out, regen, timestamps, powers
+        return fiber_in, fiber_out, regen, timestamps
 
-    def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
+    def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
         for i, layer in enumerate(self.model._layers):
             tag = f"layer {i}"
-            for attribute in attributes:
-                if hasattr(layer, attribute):
+            if hasattr(layer, "parametrizations"):
+                attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters)
+            else:
+                attribute_pool = set(layer._parameters)
+            for attribute in attribute_pool:
+                plot = (attributes is None) or (attribute in attributes)
+                if plot:
                     vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
                     if vals.ndim <= 1 and len(vals) == 1:
                         if np.iscomplexobj(vals):
@@ -483,14 +548,11 @@ class Trainer:
         if self.writer is None:
             self.setup_tb_writer()
 
-        if self.resume:
-            model_kwargs = self.checkpoint_dict["model_kwargs"]
-        else:
-            model_kwargs = None
-
-        self.define_model(model_kwargs=model_kwargs)
-        print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
+        self.define_model()
+        print(
+            f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})"
+        )
 
         title_append, subtitle = self.build_title(0)
 
@@ -515,36 +577,55 @@ class Trainer:
             ),
             0,
         )
-        self.writer_histograms(0)
+        self.writer.add_figure(
+            "powers",
+            self.plot_model_response(
+                model=self.model,
+                title_append=title_append,
+                subtitle=subtitle,
+                mode="powers",
+                show=False,
+            ),
+            0,
+        )
+
+        self.write_parameters(0)
+
+        self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path))
+
+        self.writer.flush()
 
         train_loader, valid_loader = self.get_sliced_data()
 
         optimizer_name = self.optimizer_settings.optimizer
 
-        lr = self.optimizer_settings.learning_rate
+        # lr = self.optimizer_settings.learning_rate
 
-        self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
+        self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(
+            self.model.parameters(), **self.optimizer_settings.optimizer_kwargs
+        )
         if self.optimizer_settings.scheduler is not None:
             self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
                 self.optimizer, **self.optimizer_settings.scheduler_kwargs
             )
-            if self.resume:
-                try:
-                    self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
-                except ValueError:
-                    pass
-        self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
+            # if self.resume:
+            #     try:
+            #         self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
+            #     except ValueError:
+            #         pass
+        self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1)
 
         if not self.resume:
             self.best = self.build_checkpoint_dict()
         else:
             self.best = self.checkpoint_dict
-            self.model.load_state_dict(self.best["model_state_dict"], strict=False)
-            try:
-                self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
-            except ValueError:
-                pass
+            self.best["loss"] = float("inf")
+            # self.model.load_state_dict(self.best["model_state_dict"], strict=False)
+            # try:
+            #     self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
+            # except ValueError:
+            #     pass
 
         for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
             enable_progress = True
@@ -562,12 +643,8 @@ class Trainer:
                 enable_progress=enable_progress,
             )
             if self.optimizer_settings.scheduler is not None:
-                lr_old = self.scheduler.get_last_lr()
                 self.scheduler.step(loss)
-                lr_new = self.scheduler.get_last_lr()
-                if lr_old[0] != lr_new[0]:
-                    self.writer.add_scalar("learning rate", lr_new[0], epoch)
+                self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
 
             if self.pytorch_settings.save_models and self.model is not None:
                 save_path = (
                     Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
@@ -588,7 +665,28 @@ class Trainer:
         self.writer.close()
         return self.best
 
-    def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
+    def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
+        powers = [power / powers[0] for power in powers]
+        fig, ax = plt.subplots()
+        fig.set_figwidth(18)
+        fig.suptitle(
+            f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
+        )
+        ax.semilogy(powers, marker="o")
+        ax.set_xticks(range(len(layer_names)), layer_names, rotation=90)
+        ax.set_xlabel("Layer")
+        ax.set_ylabel("Normalized Power")
+        ax.grid(which="major", axis="x")
+        ax.grid(which="major", axis="y")
+        ax.grid(which="minor", axis="y", linestyle=":")
+        fig.tight_layout()
+        if show:
+            plt.show()
+        return fig
+
+    def _plot_model_response_eye(
+        self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
+    ):
         if sps is None:
             raise ValueError("sps must be provided")
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -603,16 +701,67 @@ class Trainer:
         if not any(labels):
             labels = [f"signal {i + 1}" for i in range(len(signals))]
 
+        x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
+        y_bins = np.zeros((2 * len(signals), 1000))
+        eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
+        # signals = [signal.cpu().numpy() for signal in signals]
+        for i in range(len(signals) * 2):
+            eye_signal = signals[i // 2][:, i % 2]  # x, y, x, y, ...
+            eye_signal = np.real(np.square(np.abs(eye_signal)))
+            data_min = np.min(eye_signal)
+            data_max = np.max(eye_signal)
+            y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
+            for j in range(len(timestamps)):
+                t = timestamps[j] / sps
+                val = eye_signal[j]
+                x = np.digitize(t % 2, x_bins) - 1
+                y = np.digitize(val, y_bins[i]) - 1
+                eye_data[i][y][x] += 1
+
+        cmap = LinearSegmentedColormap.from_list(
+            "eyemap",
+            [
+                (0, "white"),
+                (0.001, "dodgerblue"),
+                (0.1, "blue"),
+                (0.2, "cyan"),
+                (0.5, "lime"),
+                (0.8, "gold"),
+                (1, "red"),
+            ],
+        )
+
+        # ordering = np.argsort(timestamps)
+        # signals = [signal[ordering] for signal in signals]
+        # timestamps = timestamps[ordering]
+
         fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
         fig.set_figwidth(18)
         fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
-        xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
-        for j, (label, signal) in enumerate(zip(labels, signals)):
+        # xaxis = timestamps / sps
+        # xaxis = np.arange(2 * sps) / sps
+        for j, label in enumerate(labels):
+            x = eye_data[2 * j]
+            y = eye_data[2 * j + 1]
+            # x, y = signal.T
             # signal = signal.cpu().numpy()
-            for i in range(len(signal) // sps - 1):
-                x, y = signal[i * sps : (i + 2) * sps].T
-                axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
-                axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
+            # for i in range(len(signal) // sps - 1):
+            #     x, y = signal[i * sps : (i + 2) * sps].T
+            # axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+            # axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+            axs[0 + 2 * j].imshow(
+                x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
+            )
+            axs[1 + 2 * j].imshow(
+                y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
+            )
+            axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+            axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+            ymin = np.min(y_bins[:, 0])
+            ymax = np.max(y_bins[:, -1])
+            ydiff = ymax - ymin
+            axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
+            axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
             axs[0 + 2 * j].set_title(label + " x")
             axs[1 + 2 * j].set_title(label + " y")
             axs[0 + 2 * j].set_xlabel("Symbol")
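The rewritten eye plot above replaces trace overplotting with a 2-D histogram: every sample is binned by (time modulo two symbols, instantaneous power) and the counts are drawn with `imshow`. The core binning, reproduced stand-alone with synthetic data (the `sps` value is arbitrary):

```python
import numpy as np

sps = 8                                  # samples per symbol (assumed)
rng = np.random.default_rng(0)
timestamps = np.arange(4000)             # sample indices
signal = rng.random(4000)                # stand-in for |E|^2

x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
y_bins = np.linspace(signal.min(), signal.max(), 1000, endpoint=False)
eye = np.zeros((1000, 2 * sps))

t = (timestamps / sps) % 2               # fold onto a two-symbol window
xi = np.digitize(t, x_bins) - 1
yi = np.digitize(signal, y_bins) - 1
np.add.at(eye, (yi, xi), 1)              # vectorised version of the diff's loop
print(eye.sum())                         # 4000: every sample lands in one cell
```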
@@ -627,7 +776,9 @@ class Trainer:
             plt.show()
         return fig
 
-    def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
+    def _plot_model_response_head(
+        self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
+    ):
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
             labels = [labels]
         else:
@@ -640,19 +791,29 @@ class Trainer:
         if not any(labels):
             labels = [f"signal {i + 1}" for i in range(len(signals))]
 
+        ordering = np.argsort(timestamps)
+        signals = [signal[ordering] for signal in signals]
+        timestamps = timestamps[ordering]
+
         fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
         fig.set_figwidth(18)
         fig.set_figheight(4)
         fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
         for i, ax in enumerate(axs):
+            ax: plt.Axes
             for signal, label in zip(signals, labels):
                 if sps is not None:
-                    xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
+                    xaxis = timestamps / sps
                 else:
-                    xaxis = np.arange(len(signal))
+                    xaxis = timestamps
                 ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
             ax.set_xlabel("Sample" if sps is None else "Symbol")
             ax.set_ylabel("normalized power")
+            ax.minorticks_on()
+            ax.tick_params(axis="y", which="minor", left=False, right=False)
+            ax.grid(which="major", axis="x")
+            ax.grid(which="minor", axis="x", linestyle=":")
+            ax.grid(which="major", axis="y")
             ax.legend(loc="upper right")
         fig.tight_layout()
         if show:
@@ -664,22 +825,51 @@ class Trainer:
         model=None,
         title_append="",
         subtitle="",
-        mode: Literal["eye", "head"] = "head",
+        mode: Literal["eye", "head", "powers"] = "head",
         show=False,
     ):
+        if mode == "powers":
+            input_data = torch.ones(
+                1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
+            ).to(self.pytorch_settings.device)
+            model = model.to(self.pytorch_settings.device)
+            model.eval()
+            with torch.no_grad():
+                _, powers = model(input_data, trace_powers=True)
+
+            powers = [power.item() for power in powers]
+            layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
+
+            # remove dropout layers
+            mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
+            layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
+            powers = [power for power, m in zip(powers, mask) if m]
+
+            fig = self._plot_model_response_powers(
+                powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
+            )
+            return fig
+
         data_settings_backup = copy.deepcopy(self.data_settings)
         pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
-        self.data_settings.drop_first = 100 * 128
+        self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
         self.pytorch_settings.batchsize = (
             self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
         )
-        plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
+        config_path = random.choice(self.data_settings.config_path)
+        fiber_length = int(float(str(config_path).split('-')[-7])/1000)
+        plot_loader, _ = self.get_sliced_data(
+            override={
+                "num_symbols": self.pytorch_settings.batchsize,
+                "config_path": config_path,
+            }
+        )
         self.data_settings = data_settings_backup
         self.pytorch_settings = pytorch_settings_backup
 
-        fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
+        fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
         fiber_in = fiber_in.view(-1, 2)
         fiber_out = fiber_out.view(-1, 2)
         regen = regen.view(-1, 2)
@@ -687,6 +877,7 @@ class Trainer:
         fiber_in = fiber_in.numpy()
         fiber_out = fiber_out.numpy()
         regen = regen.numpy()
+        timestamps = timestamps.numpy()
 
         # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
         # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
@@ -697,9 +888,10 @@ class Trainer:
                 fiber_in,
                 fiber_out,
                 regen,
+                timestamps=timestamps,
                 labels=("fiber in", "fiber out", "regen"),
                 sps=plot_loader.dataset.samples_per_symbol,
-                title_append=title_append,
+                title_append=title_append + f" ({fiber_length} km)",
                 subtitle=subtitle,
                 show=show,
             )
@@ -709,9 +901,10 @@ class Trainer:
                 fiber_in,
                 fiber_out,
                 regen,
+                timestamps=timestamps,
                 labels=("fiber in", "fiber out", "regen"),
                 sps=plot_loader.dataset.samples_per_symbol,
-                title_append=title_append,
+                title_append=title_append + f" ({fiber_length} km)",
                 subtitle=subtitle,
                 show=show,
             )
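In the two hunks above, `plot_model_response` now picks a random dataset config and recovers the fiber length from its filename for the plot title; the `*-128-16384-<length>-...` naming convention used throughout this repository puts the length in metres seventh from the end. A worked example of that parse:

```python
# worked example of the split('-')[-7] parse used above
config_path = "data/20241115-175517-128-16384-50000-0-0-17-0-PAM4-0.ini"
fiber_length_km = int(float(config_path.split('-')[-7]) / 1000)
print(fiber_length_km)  # 50
```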
@@ -1,3 +1,6 @@
+import matplotlib
+import numpy as np
+import torch
 from hypertraining.settings import (
     GlobalSettings,
     DataSettings,
@@ -7,16 +10,20 @@ from hypertraining.settings import (
 )
 
 from hypertraining.training import Trainer
-import torch
+# import torch
 import json
 import util
 
+from rich import print as rprint
+
 global_settings = GlobalSettings(
-    seed=42,
+    seed=0xC0FFEE,
 )
 
 data_settings = DataSettings(
-    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    # config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
+    config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
     dtype="complex64",
     # symbols = (9, 20),  # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
     symbols=13,  # study: single_core_regen_20241123_011232
@@ -45,55 +52,83 @@ model_settings = ModelSettings(
     output_dim=2,
     n_hidden_layers=4,
     overrides={
-        "n_hidden_nodes_0": 8,
-        "n_hidden_nodes_1": 8,
+        "n_hidden_nodes_0": 4,
+        "n_hidden_nodes_1": 4,
         "n_hidden_nodes_2": 4,
-        "n_hidden_nodes_3": 6,
+        "n_hidden_nodes_3": 4,
     },
-    model_activation_func="PowScale",
-    # dropout_prob=0.01,
-    model_layer_function="ONN",
+    model_activation_func="EOActivation",
+    dropout_prob=0.01,
+    model_layer_function="ONNRect",
+    model_layer_kwargs={"square": True},
+    scale=True,
     model_layer_parametrizations=[
         {
             "tensor_name": "weight",
-            "parametrization": torch.nn.utils.parametrizations.orthogonal,
+            "parametrization": util.complexNN.energy_conserving,
+        },
+        {
+            "tensor_name": "alpha",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "gain",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": float("inf"),
+            },
+        },
+        {
+            "tensor_name": "phase_bias",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": 2 * torch.pi,
+            },
         },
         {
             "tensor_name": "scales",
             "parametrization": util.complexNN.clamp,
         },
-        {
-            "tensor_name": "scale",
-            "parametrization": util.complexNN.clamp,
-        },
-        {
-            "tensor_name": "bias",
-            "parametrization": util.complexNN.clamp,
-        },
+        # {
+        #     "tensor_name": "scale",
+        #     "parametrization": util.complexNN.clamp,
+        # },
+        # {
+        #     "tensor_name": "bias",
+        #     "parametrization": util.complexNN.clamp,
+        # },
         # {
         #     "tensor_name": "V",
         #     "parametrization": torch.nn.utils.parametrizations.orthogonal,
         # },
-        # {
-        #     "tensor_name": "S",
-        #     "parametrization": util.complexNN.clamp,
-        # },
+        {
+            "tensor_name": "loss",
+            "parametrization": util.complexNN.clamp,
+        },
     ],
 )
 
 optimizer_settings = OptimizerSettings(
-    optimizer="Adam",
-    learning_rate=0.05,
+    optimizer="AdamW",
+    optimizer_kwargs={
+        "lr": 0.05,
+        "amsgrad": True,
+        # "weight_decay": 1e-7,
+    },
+    # learning_rate=0.05,
     scheduler="ReduceLROnPlateau",
     scheduler_kwargs={
         "patience": 2**6,
-        "factor": 0.9,
+        "factor": 0.75,
         # "threshold": 1e-3,
         "min_lr": 1e-6,
         "cooldown": 10,
     },
 )
 
 
 def save_dict_to_file(dictionary, filename):
     """
     Save the best dictionary to a JSON file.
@@ -103,28 +138,79 @@ def save_dict_to_file(dictionary, filename):
     :param filename: Path to the JSON file where the dictionary will be saved.
     :type filename: str
     """
-    with open(filename, 'w') as f:
+    with open(filename, "w") as f:
         json.dump(dictionary, f, indent=4)
 
 
+def sweep_lengths(*lengths, model=None):
+    assert model is not None, "Model must be provided."
+    model = model
+
+    fiber_ins = {}
+    fiber_outs = {}
+    regens = {}
+    timestampss = {}
+
+    for length in lengths:
+        trainer = Trainer(
+            checkpoint_path=model,
+            settings_override={
+                "data_settings": {
+                    "config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
+                    "train_split": 1,
+                    "shuffle": True,
+                }
+            },
+        )
+        trainer.define_model()
+        loader, _ = trainer.get_sliced_data()
+        fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
+
+        fiber_ins[length] = fiber_in
+        fiber_outs[length] = fiber_out
+        regens[length] = regen
+        timestampss[length] = timestamps
+
+    data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
+    channel_names = ["" for _ in range(2 * len(lengths))]
+
+    for li, length in enumerate(lengths):
+        data[2 * li, 0, :] = timestampss[length] / 128
+        data[2 * li, 1, :] = regens[length][:, 0].abs().square()
+        data[2 * li + 1, 0, :] = timestampss[length] / 128
+        data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
+
+        channel_names[2 * li] = f"regen x {length}"
+        channel_names[2 * li + 1] = f"regen y {length}"
+
+    # get current backend
+    backend = matplotlib.get_backend()
+
+    matplotlib.use("TkCairo")
+    eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
+
+    print_attrs = ("channel", "success", "min_area")
+    with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
+        for result in eye.eye_stats:
+            print_dict = {attr: result[attr] for attr in print_attrs}
+            rprint(print_dict)
+            rprint()
+
+    eye.plot()
+    matplotlib.use(backend)
+
+
 if __name__ == "__main__":
+    # sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
+
     trainer = Trainer(
         global_settings=global_settings,
         data_settings=data_settings,
         pytorch_settings=pytorch_settings,
         model_settings=model_settings,
         optimizer_settings=optimizer_settings,
-        checkpoint_path='.models/20241128_084935_8885.tar',
-        settings_override={
-            "model_settings": {
-                # "model_activation_func": "PowScale",
-                "dropout_prob": 0,
-            }
-        },
-        reset_epoch=True,
+        # checkpoint_path=".models/best_20241202_143149.tar",
+        # 20241202_143149
     )
-    best = trainer.train()
-    save_dict_to_file(best, ".models/best_results.json")
+    trainer.train()
 
-    ...
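The `model_layer_parametrizations` entries above are plain dicts that the `regenerator` constructor applies to every layer owning a matching parameter (see the `@@ -74,39 +78,57 @@` hunk earlier). A self-contained reproduction of that dispatch, using torch's built-in orthogonal parametrization on stock `Linear` layers; the repository's `util.complexNN.clamp` and `energy_conserving` are assumed to follow the same `fn(module, name, **kwargs)` convention:

```python
import torch.nn as nn
from torch.nn.utils import parametrizations

layers = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
layer_parametrizations = [
    {"tensor_name": "weight", "parametrization": parametrizations.orthogonal},
]

for layer in layers:
    for lp in layer_parametrizations:
        name = lp.get("tensor_name", None)
        fn = lp.get("parametrization", None)
        kwargs = lp.get("kwargs", {})
        # only touch layers that actually own a parameter with this name
        if name is not None and name in layer._parameters and fn is not None:
            fn(layer, name, **kwargs)

print(layers[0].parametrizations)  # weight is now orthogonally parametrized
```

Listing parametrizations that some layers lack (e.g. `alpha`, `gain`) is harmless: the name check simply skips those layers.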
src/single-core-regen/sliced_dataset_test.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+# move into dir single-core-regen before running
+
+from util.datasets import FiberRegenerationDataset
+from torch.utils.data import DataLoader
+from matplotlib import pyplot as plt
+import numpy as np
+
+# def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
+#     if no_symbols is None:
+#         no_symbols = len(dataset)
+#     _, axs = plt.subplots(2,2, sharex=True, sharey=True)
+
+#     xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
+#     roll = dataset.samples_per_symbol//2 if offset else 0
+#     for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]:
+#         E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
+#         axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0')
+#         axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0')
+#         axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0')
+#         axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0')
+
+#     if show:
+#         plt.show()
+
+# # def plt_dataloader(dataloader, show=True):
+# #     _, axs = plt.subplots(2,2, sharex=True, sharey=True)
+
+# #     E_outs, E_ins = next(iter(dataloader))
+# #     for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
+# #         xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
+# #         E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
+# #         axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
+# #         axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
+# #         axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
+# #         axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
+
+# #     if show:
+# #         plt.show()
+
+if __name__ == "__main__":
+
+    dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
+
+    loader = DataLoader(dataset, batch_size=10, shuffle=True)
+
+    x = []
+    y_fiber_in = []
+    y_fiber_out = []
+
+    for i, batch in enumerate(loader):
+        # if i > 128:
+        #     break
+
+        fiber_in, fiber_out, timestamp = batch
+
+        fiber_out = fiber_out.reshape(fiber_out.shape[0], -1, 2)
+        fiber_out = fiber_out[:,fiber_out.shape[1]//2, :]
+
+        # input_data = input_data.reshape(-1,2)
+        # target = target.reshape(-1,2).squeeze()
+        # timestamp = timestamp.reshape(-1,1).squeeze()
+
+        x.append(timestamp.detach().numpy())
+        y_fiber_in.append(fiber_in.abs().square().detach().numpy())
+        y_fiber_out.append(fiber_out.abs().square().detach().numpy())
+
+    x = np.concat(x)
+    y_fiber_in = np.concat(y_fiber_in)
+    y_fiber_out = np.concat(y_fiber_out)
+
+    # order = np.argsort(x)
+    # x = x[order]
+    # y = y[order]
+
+    fig, axs = plt.subplots(2,2, sharex=True, sharey=True)
+    axs[0,0].scatter((x/dataset.samples_per_symbol)%2, y_fiber_in[:,0], s=1, alpha=0.1)
+    axs[1,0].scatter((x/dataset.samples_per_symbol)%2, y_fiber_in[:,1], s=1, alpha=0.1)
+    axs[0,1].scatter((x/dataset.samples_per_symbol)%2, y_fiber_out[:,0], s=1, alpha=0.1)
+    axs[1,1].scatter((x/dataset.samples_per_symbol)%2, y_fiber_out[:,1], s=1, alpha=0.1)
+    plt.show()
+
+    # eye_dataset(dataset, 1000, offset=True, show=False)
+
+    # train_loader = DataLoader(dataset, batch_size=10, shuffle=False)
+
+    # plt_dataloader(train_loader, show=False)
+
+    # plt.show()
@@ -1,51 +0,0 @@
-# move into dir single-core-regen before running
-
-from util.dataset import SlicedDataset
-from torch.utils.data import DataLoader
-from matplotlib import pyplot as plt
-import numpy as np
-
-def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
-    if no_symbols is None:
-        no_symbols = len(dataset)
-    _, axs = plt.subplots(2,2, sharex=True, sharey=True)
-
-    xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
-    roll = dataset.samples_per_symbol//2 if offset else 0
-    for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]:
-        E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
-        axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0')
-        axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0')
-        axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0')
-        axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0')
-
-    if show:
-        plt.show()
-
-# def plt_dataloader(dataloader, show=True):
-#     _, axs = plt.subplots(2,2, sharex=True, sharey=True)
-
-#     E_outs, E_ins = next(iter(dataloader))
-#     for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
-#         xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
-#         E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
-#         axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
-#         axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
-#         axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
-#         axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
-
-#     if show:
-#         plt.show()
-
-if __name__ == "__main__":
-
-    dataset = SlicedDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=1, drop_first=100)
-    print(dataset[0][0].shape)
-
-    eye_dataset(dataset, 1000, offset=True, show=False)
-
-    train_loader = DataLoader(dataset, batch_size=10, shuffle=False)
-
-    # plt_dataloader(train_loader, show=False)
-
-    plt.show()
@@ -17,3 +17,5 @@ from . import complexNN  # noqa: F401
 # from .complexNN import complex_sse_loss  # noqa: F401
 
 from . import misc  # noqa: F401
+
+from . import eye_diagram  # noqa: F401
@@ -4,23 +4,36 @@ import torch.nn.functional as F
 # from torchlambertw.special import lambertw


-def complex_mse_loss(input, target, power=False, reduction="mean"):
+def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"):
     """
     Compute the mean squared error between two complex tensors.
     If power is set to True, the loss is computed as |input|^2 - |target|^2
     """
     reduce = getattr(torch, reduction)
+    power_penalty = 0

     if power:
         input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
         target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
+        if normalize:
+            power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2
+            power_penalty += (input.min() - target.min()) ** 2
+            input = input - input.min()
+            input = input / input.max()
+            target = target - target.min()
+            target = target / target.max()
+    else:
+        if normalize:
+            power_penalty = (input.abs().max() - target.abs().max()) ** 2
+            input = input / input.abs().max()
+            target = target / target.abs().max()

     if input.is_complex() and target.is_complex():
-        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
+        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty
     elif input.is_complex() or target.is_complex():
         raise ValueError("Input and target must have the same type (real or complex)")
     else:
-        return F.mse_loss(input, target, reduction=reduction)
+        return F.mse_loss(input, target, reduction=reduction) + power_penalty


 def complex_sse_loss(input, target):
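Reviewer note: a minimal sketch of how the new `normalize` flag behaves. Names and values are invented for illustration, and the import path is an assumption about this package layout:

```python
import torch

from util.complexNN import complex_mse_loss  # assumed import path

# hypothetical prediction/reference with a deliberate power mismatch
pred = 2.0 * torch.randn(64, dtype=torch.complex64)
ref = torch.randn(64, dtype=torch.complex64)

plain = complex_mse_loss(pred, ref)                   # raw error, power gap dominates
scaled = complex_mse_loss(pred, ref, normalize=True)  # compares shapes, adds an explicit power penalty
print(plain.item(), scaled.item())
```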
@@ -53,14 +66,10 @@ class UnitaryLayer(nn.Module):
         return f"UnitaryLayer({self.in_features}, {self.out_features})"


 class _Unitary(nn.Module):
     def forward(self, X: torch.Tensor):
         if X.ndim < 2:
-            raise ValueError(
-                "Only tensors with 2 or more dimensions are supported. "
-                f"Got a tensor of shape {X.shape}"
-            )
+            raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}")
         n, k = X.size(-2), X.size(-1)
         transpose = n < k
         if transpose:
@@ -80,6 +89,7 @@ class _Unitary(nn.Module):
         # X.copy_(q)
         return q


 def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
     weight = getattr(module, name, None)
     if not isinstance(weight, torch.Tensor):
@@ -95,6 +105,7 @@ def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
     nn.utils.parametrize.register_parametrization(module, name, unit)
     return module


 class _SpecialUnitary(nn.Module):
     def __init__(self):
         super().__init__()
@@ -108,6 +119,7 @@ class _SpecialUnitary(nn.Module):

         return q


 def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
     weight = getattr(module, name, None)
     if not isinstance(weight, torch.Tensor):
@@ -123,11 +135,13 @@ def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
     nn.utils.parametrize.register_parametrization(module, name, unit)
     return module


 class _Clamp(nn.Module):
     def __init__(self, min, max):
         super(_Clamp, self).__init__()
         self.min = min
         self.max = max

     def forward(self, x):
         if x.is_complex():
             # clamp magnitude, ignore phase
@@ -145,43 +159,29 @@ def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
     return module


-class ONNMiller(nn.Module):
-    def __init__(self, input_dim, output_dim, dtype=None) -> None:
-        super(ONNMiller, self).__init__()
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.dtype = dtype
-
-        self.dim = max(input_dim, output_dim)
-
-        # zero pad input to internal size if smaller
-        if self.input_dim < self.dim:
-            self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
-        else:
-            self.pad = lambda x: x
-        self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
-
-        # crop output to desired size
-        if self.output_dim < self.dim:
-            self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
-        else:
-            self.crop = lambda x: x
-        self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
-
-        self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))  # -> parametrization: Unitary
-        self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype))  # -> parametrization: Clamp (magnitude 0..1)
-        self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))  # -> parametrization: Unitary
-        self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
-        # V is actually V.H, but
-
-    def forward(self, x_in):
-        x = x_in
-        x = self.pad(x)
-        x = x @ self.U
-        x = x * (self.S.squeeze() / self.MZI_scale)
-        x = x @ self.V
-        x = self.crop(x)
-        return x
+class _EnergyConserving(nn.Module):
+    def __init__(self):
+        super(_EnergyConserving, self).__init__()
+
+    def forward(self, X: torch.Tensor):
+        if X.ndim == 2:
+            X = X.unsqueeze(0)
+        spectral_norm = torch.linalg.svdvals(X)[:, 0]
+        return (X / spectral_norm).squeeze()
+
+
+def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module:
+    param = getattr(module, name, None)
+    if not isinstance(param, torch.Tensor):
+        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
+
+    if not (2 <= param.ndim <= 3):
+        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.")
+
+    unit = _EnergyConserving()
+    nn.utils.parametrize.register_parametrization(module, name, unit)
+    return module


 class ONN(nn.Module):
     def __init__(self, input_dim, output_dim, dtype=None) -> None:
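For review context: `_EnergyConserving` rescales a weight matrix by its largest singular value, so the parametrized layer can attenuate but never amplify total power. A quick sketch, assuming the module is importable as `util.complexNN`:

```python
import torch
from torch import nn

from util.complexNN import energy_conserving  # assumed import path

layer = nn.Linear(8, 8, bias=False).to(torch.complex64)
energy_conserving(layer, "weight")

# the effective weight now has spectral norm 1: no input gains energy
print(torch.linalg.svdvals(layer.weight)[0])  # ~1.0
x = torch.randn(4, 8, dtype=torch.complex64)
print((layer(x).abs() ** 2).sum() <= (x.abs() ** 2).sum() + 1e-4)  # True
```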
@@ -202,18 +202,21 @@ class ONN(nn.Module):

         # crop output to desired size
         if self.output_dim < self.dim:
-            self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
+            self.crop = lambda x: x[
+                :, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)
+            ]
             self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
         else:
             self.crop = lambda x: x
             self.crop.__doc__ = f"Output size equals internal size {self.dim}"


         self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
+        # self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5)

     def reset_parameters(self):
         q, _ = torch.linalg.qr(self.weight)
         self.weight.data = q

     # def get_M(self):
     #     return self.U @ self.sigma @ self.V

@@ -221,37 +224,50 @@ class ONN(nn.Module):
         return self.crop(self.pad(x) @ self.weight)


-class SemiUnitaryLayer(nn.Module):
-    def __init__(self, input_dim, output_dim, dtype=None):
-        super(SemiUnitaryLayer, self).__init__()
+class ONNRect(nn.Module):
+    def __init__(self, input_dim, output_dim, square=False, dtype=None):
+        super(ONNRect, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim

-        # Create a larger square matrix for QR decomposition
-        self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
-        self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        # Ensure the weights are unitary by QR decomposition
-        q, _ = torch.linalg.qr(self.weight)
-        # A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
-
-        # truncate the matrix to the desired size
-        if self.input_dim > self.output_dim:
-            self.weight.data = q[: self.input_dim, : self.output_dim]
-        else:
-            self.weight.data = q[: self.output_dim, : self.input_dim].t()
-        ...
+        if square:
+            dim = max(input_dim, output_dim)
+            self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))
+
+            # zero pad input to internal size if smaller
+            if self.input_dim < dim:
+                self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2))
+                self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}"
+            else:
+                self.pad = lambda x: x
+                self.pad.__doc__ = f"Input size equals internal size {dim}"
+
+            # crop output to desired size
+            if self.output_dim < dim:
+                self.crop = lambda x: x[
+                    :, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2)
+                ]
+                self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}"
+            else:
+                self.crop = lambda x: x
+                self.crop.__doc__ = f"Output size equals internal size {dim}"
+
+        else:
+            self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype))
+            self.pad = lambda x: x
+            self.pad.__doc__ = "No padding"
+            self.crop = lambda x: x
+            self.crop.__doc__ = "No cropping"

     def forward(self, x):
-        with torch.no_grad():
-            scale = torch.clamp(self.scale, 0.0, 1.0)
-        out = torch.matmul(x, scale * self.weight)
+        x = self.pad(x)
+        out = self.crop((self.weight @ x.mT).mT)
         return out

-    def __repr__(self):
-        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
+    # def __repr__(self):
+    #     return f"ONNRect({self.input_dim}, {self.output_dim})"


 # class SaturableAbsorberLambertW(nn.Module):
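A shape-level sketch of the renamed layer (import path assumed): `square=True` keeps the padded/cropped square internal matrix, `square=False` uses a plain rectangular weight.

```python
import torch

from util.complexNN import ONNRect  # assumed import path

layer = ONNRect(6, 4, square=True, dtype=torch.complex64)
x = torch.randn(10, 6, dtype=torch.complex64)
print(layer(x).shape)  # torch.Size([10, 4]): 6x6 internally, then center-cropped
```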
@@ -336,6 +352,19 @@ class DropoutComplex(nn.Module):
         return self.dropout(x)


+class Scale(nn.Module):
+    def __init__(self, size):
+        super(Scale, self).__init__()
+        self.size = size
+        self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
+
+    def forward(self, x):
+        return x * self.scale
+
+    def __repr__(self):
+        return f"Scale({self.size})"
+
+
 class Identity(nn.Module):
     """
     implements the "activation" function
@@ -348,6 +377,7 @@ class Identity(nn.Module):
     def forward(self, x):
         return x


 class PowRot(nn.Module):
     def __init__(self, bias=False):
         super(PowRot, self).__init__()
@@ -363,11 +393,71 @@ class PowRot(nn.Module):
         else:
             return x


+class MZISingle(nn.Module):
+    def __init__(self, bias, size, func=None):
+        super(MZISingle, self).__init__()
+        self.omega = nn.Parameter(torch.randn(size))
+        self.phi = nn.Parameter(torch.randn(size))
+        self.func = func or (lambda x: x.abs().square())  # default to |z|^2
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
+
+
+class EOActivation(nn.Module):
+    def __init__(self, bias, size=None):
+        # 10.1109/SiPhotonics60897.2024.10543376
+        super(EOActivation, self).__init__()
+        if size is None:
+            raise ValueError("Size must be specified")
+        self.size = size
+        self.alpha = nn.Parameter(torch.ones(size))
+        self.V_bias = nn.Parameter(torch.ones(size))
+        self.gain = nn.Parameter(torch.ones(size))
+        # if bias:
+        #     self.phase_bias = nn.Parameter(torch.zeros(size))
+        # else:
+        #     self.register_buffer("phase_bias", torch.zeros(size))
+        self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
+        self.register_buffer("responsivity", torch.ones(size)*0.9)
+        self.register_buffer("V_pi", torch.ones(size)*3)
+
+        self.reset_weights()
+
+    def reset_weights(self):
+        if "alpha" in self._parameters:
+            self.alpha.data = torch.ones(self.size)*0.5
+        if "V_pi" in self._parameters:
+            self.V_pi.data = torch.ones(self.size)*3
+        if "V_bias" in self._parameters:
+            self.V_bias.data = torch.zeros(self.size)
+        if "gain" in self._parameters:
+            self.gain.data = torch.ones(self.size)
+        if "responsivity" in self._parameters:
+            self.responsivity.data = torch.ones(self.size)*0.9
+        if "bias" in self._parameters:
+            self.phase_bias.data = torch.zeros(self.size)
+
+    def forward(self, x: torch.Tensor):
+        phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
+        g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
+        intermediate = g_phi * x.abs().square() + phi_b
+        return (
+            1j
+            * torch.sqrt(1 - self.alpha)
+            * torch.exp(-0.5j * (intermediate + self.phase_bias))
+            * torch.cos(0.5 * intermediate)
+            * x
+        )
+
+
 class Pow(nn.Module):
     """
     implements the activation function
     M(z) = ||z||^2 + b
     """

     def __init__(self, bias=False):
         super(Pow, self).__init__()
         if bias:
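Both new activations act elementwise on complex fields; a short sketch of calling them, with the import path assumed and shapes arbitrary:

```python
import torch

from util.complexNN import EOActivation, MZISingle  # assumed import path

act = EOActivation(bias=False, size=16)
z = torch.randn(32, 16, dtype=torch.complex64)
out = act(z)  # same shape; intensity-dependent phase shift and attenuation
print(out.shape, out.dtype)  # torch.Size([32, 16]) torch.complex64

mzi = MZISingle(bias=False, size=16)
print(mzi(z).shape)  # torch.Size([32, 16])
```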
@@ -375,7 +465,6 @@ class Pow(nn.Module):
         else:
             self.register_buffer("bias", torch.tensor(0.0))

-
     def forward(self, x: torch.Tensor):
         return x.abs().square().add(self.bias).to(dtype=x.dtype)

@@ -408,6 +497,7 @@ class MagScale(nn.Module):
     def forward(self, x: torch.Tensor):
         return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)


 class PowScale(nn.Module):
     def __init__(self, bias=False):
         super(PowScale, self).__init__()
@@ -486,10 +576,10 @@ __all__ = [
     complex_mse_loss,
     UnitaryLayer,
     unitary,
+    energy_conserving,
     clamp,
     ONN,
-    ONNMiller,
-    SemiUnitaryLayer,
+    ONNRect,
     DropoutComplex,
     Identity,
     Pow,
@@ -498,6 +588,8 @@ __all__ = [
     ModReLU,
     CReLU,
     ZReLU,
+    MZISingle,
+    EOActivation,
     # SaturableAbsorberLambertW,
     # SaturableAbsorber,
     # SpreadLayer,
@@ -40,7 +40,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
     if symbols is None:
         symbols = int(config["glova"]["nos"]) - skipfirst

-    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
+    data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
+    timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))

     if normalize:
         # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
@@ -53,6 +54,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals

     config["glova"]["nos"] = str(symbols)

+    data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1)
+
     data = torch.tensor(data, device=device, dtype=dtype)

     return data, config
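The new int() casts matter once skipfirst or symbols can be fractional; a worked example of the index arithmetic (numbers invented):

```python
import numpy as np

sps = 128        # samples per symbol
skipfirst = 2.5  # fractional symbol offsets are now valid
symbols = 4

start = int(skipfirst * sps)                 # 320
stop = int(symbols * sps + skipfirst * sps)  # 832
timestamps = np.arange(start, stop)          # absolute sample indices, kept with the data
print(start, stop, timestamps.shape)         # 320 832 (512,)
```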
@@ -100,7 +103,7 @@ class FiberRegenerationDataset(Dataset):

     def __init__(
         self,
-        file_path: str | Path,
+        file_path: tuple | list | str | Path,
         symbols: int | float,
         *,
         output_dim: int = None,
@@ -130,12 +133,12 @@ class FiberRegenerationDataset(Dataset):
         """

         # check types
-        assert isinstance(file_path, str), "file_path must be a string"
+        assert isinstance(file_path, (str, Path, tuple, list)), "file_path must be a string, Path, tuple, or list"
         assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
         assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
         assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
         assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
-        assert isinstance(drop_first, int), "drop_first must be an integer"
+        # assert isinstance(drop_first, int), "drop_first must be an integer"

         # check values
         assert symbols > 0, "symbols must be positive"
@@ -150,20 +153,38 @@ class FiberRegenerationDataset(Dataset):
                 dtype=np.complex128,
             )
             data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
+            timestamps = torch.arange(12800)
+
+            data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
+
             self.config = {
                 "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                 "glova": {"sps": 128},
             }
         else:
-            data_raw, self.config = load_data(
-                file_path,
-                skipfirst=drop_first,
-                symbols=kwargs.pop("num_symbols", None),
-                real=real,
-                normalize=True,
-                device=device,
-                dtype=dtype,
-            )
+            data_raw = None
+            self.config = None
+            files = []
+            for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
+                data, config = load_data(
+                    file_path,
+                    skipfirst=drop_first,
+                    symbols=kwargs.get("num_symbols", None),
+                    real=real,
+                    normalize=True,
+                    device=device,
+                    dtype=dtype,
+                )
+                if data_raw is None:
+                    data_raw = data
+                else:
+                    data_raw = torch.cat([data_raw, data], dim=0)
+                if self.config is None:
+                    self.config = config
+                else:
+                    assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
+                files.append(config["data"]["file"].strip('"'))
+            self.config["data"]["file"] = str(files)

         self.device = data_raw.device

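Sketch of the new multi-file constructor; class and keyword names are from this diff, the module path and file names are hypothetical:

```python
from datasets import FiberRegenerationDataset  # assumed import path

dataset = FiberRegenerationDataset(
    file_path=(
        "data/run-a-PAM4-0.ini",  # hypothetical config files
        "data/run-b-PAM4-0.ini",
    ),
    symbols=8,
)
# recordings are concatenated along the sample axis; construction asserts
# that every file uses the same samples-per-symbol setting
```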
@@ -190,10 +211,10 @@ class FiberRegenerationDataset(Dataset):
         # data_raw = torch.tensor(data_raw, dtype=dtype)

         # data layout
-        # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
-        #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
+        # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0, timestamp0],
+        #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1, timestamp1],
         #   ...
-        #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
+        #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN, timestampN] ]

         data_raw = data_raw.transpose(0, 1)

@@ -201,16 +222,18 @@ class FiberRegenerationDataset(Dataset):
         # [ E_in_x[0:N],
         #   E_in_y[0:N],
         #   E_out_x[0:N],
-        #   E_out_y[0:N] ]
+        #   E_out_y[0:N],
+        #   timestamps[0:N] ]

         # shift x data by xy_delay_samples relative to the y data (example value: 3)
         # [ E_in_x [0:N],        [ E_in_x [ 0:N ],      [ E_in_x [3:N ],
         #   E_in_y [0:N],   ->     E_in_y [-3:N-3],  ->   E_in_y [0:N-3],
         #   E_out_x[0:N],          E_out_x[ 0:N ],        E_out_x[3:N ],
-        #   E_out_y[0:N] ]         E_out_y[-3:N-3] ]      E_out_y[0:N-3] ]
+        #   E_out_y[0:N] ]         E_out_y[-3:N-3] ]      E_out_y[0:N-3],
+        #   timestamps[0:N] ]      timestamps[ 0:N ] ]    timestamps[3:N ] ]

         if self.xy_delay_samples != 0:
-            data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
+            data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples, 0], dim=1)
             if self.xy_delay_samples > 0:
                 data_raw = data_raw[:, self.xy_delay_samples :]
             elif self.xy_delay_samples < 0:
@@ -221,12 +244,13 @@ class FiberRegenerationDataset(Dataset):
         # [ E_in_x [0:N],        [ E_in_x [-5:N-5],     [ E_in_x [0:N-5],
         #   E_in_y [0:N],   ->     E_in_y [-5:N-5],  ->   E_in_y [0:N-5],
         #   E_out_x[0:N],          E_out_x[ 0:N ],        E_out_x[5:N ],
-        #   E_out_y[0:N] ]         E_out_y[ 0:N ] ]       E_out_y[5:N ] ]
+        #   E_out_y[0:N] ]         E_out_y[ 0:N ] ]       E_out_y[5:N ],
+        #   timestamps[0:N] ]      timestamps[ 0:N ] ]    timestamps[5:N ]

         if self.target_delay_samples != 0:
             data_raw = roll_along(
                 data_raw,
-                [self.target_delay_samples, self.target_delay_samples, 0, 0],
+                [self.target_delay_samples, self.target_delay_samples, 0, 0, 0],
                 dim=1,
             )
             if self.target_delay_samples > 0:
@@ -234,21 +258,25 @@ class FiberRegenerationDataset(Dataset):
         elif self.target_delay_samples < 0:
             data_raw = data_raw[:, : self.target_delay_samples]

+        timestamps = data_raw[-1, :]
+        data_raw = data_raw[:-1, :]
         data_raw = data_raw.view(2, 2, -1)
+        timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
+        data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
         # data layout
-        # [ [E_in_x, E_in_y],
-        #   [E_out_x, E_out_y] ]
+        # [ [E_in_x, E_in_y, timestamps],
+        #   [E_out_x, E_out_y, timestamps] ]

         self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
         self.data = self.data.movedim(-2, 0)
-        # -> [no_slices, 2, 2, samples_per_slice]
+        # -> [no_slices, 2, 3, samples_per_slice]

         # data layout
         # [
-        #   [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
-        #   [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
+        #   [ [E_in_x[0:N+0], E_in_y[0:N+0], timestamps[0:N+0]], [ E_out_x[0:N+0], E_out_y[0:N+0], timestamps[0:N+0] ] ],
+        #   [ [E_in_x[1:N+1], E_in_y[1:N+1], timestamps[1:N+1]], [ E_out_x[1:N+1], E_out_y[1:N+1], timestamps[1:N+1] ] ],
         #   ...
-        # ] -> [no_slices, 2, 2, samples_per_slice]
+        # ] -> [no_slices, 2, 3, samples_per_slice]

         ...

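How unfold builds the overlapping slices that back `__getitem__` — a small worked example:

```python
import torch

# 8 samples, window of 4, stride 1 -> 5 overlapping slices
x = torch.arange(8).reshape(1, 1, 8)
slices = x.unfold(dimension=-1, size=4, step=1)  # -> [1, 1, 5, 4]
slices = slices.movedim(-2, 0)                   # -> [5, 1, 1, 4]
print(slices[0].flatten().tolist())  # [0, 1, 2, 3]
print(slices[1].flatten().tolist())  # [1, 2, 3, 4]
```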
@@ -259,24 +287,24 @@ class FiberRegenerationDataset(Dataset):
         if isinstance(idx, slice):
             return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
         else:
-            data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
-
-            # reduce by by taking self.output_dim equally spaced samples
-            data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
-            data = data.view(data.shape[0], self.output_dim, -1)
-            data = data[:, :, 0]
-
-            # target is corresponding to the middle of the data as the output sample is influenced by the data before and after it
-            target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
-            target = target.view(target.shape[0], self.output_dim, -1)
-            target = target[:, 0, target.shape[2] // 2]
+            data_slice = self.data[idx].squeeze()
+            data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
+            data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
+
+            target = data_slice[0, :, self.output_dim//2, 0]
+            data = data_slice[1, :, :, 0]
+
+            # data_timestamps = data[-1,:].real
+            data = data[:-1, :]
+            target_timestamp = target[-1].real
+            target = target[:-1]

             data = data.transpose(0, 1).flatten().squeeze()
+            # data_timestamps = data_timestamps.flatten().squeeze()
             target = target.flatten().squeeze()
+            target_timestamp = target_timestamp.flatten().squeeze()

-            # data layout:
-            # [sample_x0, sample_y0, sample_x1, sample_y1, ...]
-            # target layout:
-            # [sample_x0, sample_y0]
-
-            return data, target
+            return data, target, target_timestamp
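`__getitem__` now returns a 3-tuple, so downstream loops need one extra unpack; a sketch assuming a dataset built as above:

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, shuffle=False)
for data, target, target_timestamp in loader:
    # data: interleaved x/y samples of the distorted field
    # target: the field sample to regenerate (middle of the slice)
    # target_timestamp: its absolute sample index, e.g. for eye diagrams
    break
```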
418 src/single-core-regen/util/eye_diagram.py Normal file
@@ -0,0 +1,418 @@
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from scipy.cluster.vq import kmeans2
import warnings

from rich.traceback import install
from rich import pretty
from rich import print

install()
pretty.install()


def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
    data = create_symbol_sequence(n_symbols, skew=skew)
    signal = generate_signal(data, sps)
    signal = normalization_with_noise(signal, noise)

    xaxis = np.arange(0, len(signal)) / sps
    return np.vstack([xaxis, signal])


def create_symbol_sequence(n_symbols, skew=1):
    np.random.seed(42)
    data = np.random.randint(0, 4, n_symbols) / 4
    data = np.pow(data, skew)
    return tuple(data)


def generate_signal(data, sps):
    working_data = np.diff(data, prepend=data[0])
    data_padded = np.zeros(len(data) * sps)
    data_padded[::sps] = working_data
    data_padded = np.pad(data_padded, (0, sps // 2), mode="constant")

    wavelet = generate_wavelet(sps, oversample=3)

    signal = np.convolve(data_padded, wavelet)
    signal = np.cumsum(signal)
    signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]

    return signal


def normalization_with_noise(signal, noise=0):
    if noise > 0:
        awgn = np.random.normal(0, noise * (np.max(signal) - np.min(signal)), len(signal))
        signal += awgn

    # min-max normalization
    signal = signal - np.min(signal)
    signal = signal / np.max(signal)
    return signal


def generate_wavelet(sps, oversample=3):
    sample_points = np.linspace(
        -oversample * sps,
        oversample * sps,
        2 * oversample * sps,
        endpoint=True,
    )
    sigma = 0.33 / (1 * np.sqrt(2 * np.log(2))) * sps
    pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))

    return pulse

class eye_diagram:
    def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4):
        # data has shape [channels, 2, samples]
        # each sample has a timestamp and a value
        if data.ndim == 2:
            data = data[np.newaxis, :, :]
        self.channel_names = channel_names
        self.raw_data = data
        self.channels = data.shape[0]
        self.n_levels = n_levels
        self.eye_stats = [{"success": False} for _ in range(self.channels)]
        self.horizontal_bins = horizontal_bins
        self.vertical_bins = vertical_bins
        self.eye_built = False
        self.analyse(self.n_levels)

    def generate_eye_data(self):
        self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
        self.y_bins = np.zeros((self.channels, self.vertical_bins))
        self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
        for i in range(self.channels):
            data_min = np.min(self.raw_data[i, 1, :])
            data_max = np.max(self.raw_data[i, 1, :])
            self.y_bins[i] = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)

            t_vals = self.raw_data[i, 0, :] % 2
            val_vals = self.raw_data[i, 1, :]

            x_indices = np.digitize(t_vals, self.x_bins) - 1
            y_indices = np.digitize(val_vals, self.y_bins[i]) - 1

            np.add.at(self.eye_data[i], (y_indices, x_indices), 1)
        self.eye_built = True

    def plot(self, title="Eye Diagram", stats=True, show=True):
        if not self.eye_built:
            self.generate_eye_data()
        cmap = LinearSegmentedColormap.from_list(
            "eyemap",
            [(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
        )
        if self.channels % 2 == 0:
            rows = 2
            cols = self.channels // 2
        else:
            cols = int(np.ceil(np.sqrt(self.channels)))
            rows = int(np.ceil(self.channels / cols))
        fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
        fig.suptitle(title)
        ax = np.atleast_1d(ax).transpose().flatten()
        for i in range(self.channels):
            ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
            ax[i].set_xlabel("Symbol")
            ax[i].set_ylabel("Amplitude")
            ax[i].grid()
            ax[i].imshow(
                self.eye_data[i],
                origin="lower",
                aspect="auto",
                cmap=cmap,
                extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
            )
            ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
            ymin = np.min(self.y_bins[:, 0])
            ymax = np.max(self.y_bins[:, -1])
            yspan = ymax - ymin
            ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
            if stats and self.eye_stats[i]["success"]:
                ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
                ax[i].set_yticks(self.eye_stats[i]["levels"])
                # add arrows for amplitudes
                for j in range(len(self.eye_stats[i]["amplitudes"])):
                    ax[i].annotate(
                        "",
                        xy=(0.05, self.eye_stats[i]["levels"][j]),
                        xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
                        arrowprops=dict(arrowstyle="<->", facecolor="black"),
                    )
                    ax[i].annotate(
                        f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
                        xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
                    )
                # add arrows for eye heights
                for j in range(len(self.eye_stats[i]["heights"])):
                    try:
                        bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
                        top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])

                        ax[i].annotate(
                            "",
                            xy=(self.eye_stats[i]["time_midpoint"], bot),
                            xytext=(self.eye_stats[i]["time_midpoint"], top),
                            arrowprops=dict(arrowstyle="<->", facecolor="black"),
                        )
                        ax[i].annotate(
                            f"{self.eye_stats[i]['heights'][j]:.2e}",
                            xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
                        )
                    except (ValueError, IndexError):
                        pass
                # add arrows for eye widths
                for j in range(len(self.eye_stats[i]["widths"])):
                    try:
                        left = np.max(self.eye_stats[i]["time_clusters"][j][0])
                        right = np.min(self.eye_stats[i]["time_clusters"][j][1])
                        vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2

                        ax[i].annotate(
                            "",
                            xy=(left, vertical),
                            xytext=(right, vertical),
                            arrowprops=dict(arrowstyle="<->", facecolor="black"),
                        )
                        ax[i].annotate(
                            f"{self.eye_stats[i]['widths'][j]:.2e}",
                            xy=((left + right) / 2 - 0.15, vertical + 0.01),
                        )
                    except (ValueError, IndexError):
                        pass

                # add area
                for j in range(len(self.eye_stats[i]["areas"])):
                    horizontal = self.eye_stats[i]["time_midpoint"]
                    vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
                    ax[i].annotate(
                        f"{self.eye_stats[i]['areas'][j]:.2e}",
                        xy=(horizontal + 0.035, vertical - 0.07),
                    )

                # add min_area above the plot
                ax[i].annotate(
                    f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
                    xy=(0.05, ymax + 0.05 * yspan),
                    # xycoords="axes fraction",
                    ha="left",
                    va="center",
                )
        fig.tight_layout()

        if show:
            plt.show()
        return fig

    def analyse(self, n_levels=4):
        warnings.filterwarnings("error")
        for i in range(self.channels):
            self.eye_stats[i]["channel"] = str(i+1) if self.channel_names is None else self.channel_names[i]
            try:
                approx_levels = eye_diagram.approximate_levels(self.raw_data[i], n_levels)

                time_bounds = eye_diagram.calculate_time_bounds(self.raw_data[i], approx_levels)

                self.eye_stats[i]["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2

                self.eye_stats[i]["levels"], self.eye_stats[i]["amplitude_clusters"] = eye_diagram.calculate_levels(
                    self.raw_data[i], approx_levels, time_bounds
                )

                self.eye_stats[i]["amplitudes"] = np.diff(self.eye_stats[i]["levels"])

                self.eye_stats[i]["heights"] = eye_diagram.calculate_eye_heights(
                    self.eye_stats[i]["amplitude_clusters"]
                )

                self.eye_stats[i]["widths"], self.eye_stats[i]["time_clusters"] = eye_diagram.calculate_eye_widths(
                    self.raw_data[i], self.eye_stats[i]["levels"]
                )

                # # check if time clusters are valid (upper bound > time_midpoint > lower bound)
                # # if not: raise ValueError
                # for j in range(len(self.eye_stats[i]['time_clusters'])):
                #     if not (np.max(self.eye_stats[i]['time_clusters'][j][0]) < self.eye_stats[i]["time_midpoint"] < np.min(self.eye_stats[i]['time_clusters'][j][1])):
                #         raise ValueError

                self.eye_stats[i]["areas"] = self.eye_stats[i]["heights"] * self.eye_stats[i]["widths"]
                self.eye_stats[i]["mean_area"] = np.mean(self.eye_stats[i]["areas"])
                self.eye_stats[i]["min_area"] = np.min(self.eye_stats[i]["areas"])

                self.eye_stats[i]["success"] = True
            except (RuntimeWarning, UserWarning, ValueError):
                self.eye_stats[i]["success"] = False
                self.eye_stats[i]["time_midpoint"] = 0
                self.eye_stats[i]["levels"] = np.zeros(n_levels)
                self.eye_stats[i]["amplitude_clusters"] = []
                self.eye_stats[i]["amplitudes"] = np.zeros(n_levels - 1)
                self.eye_stats[i]["heights"] = np.zeros(n_levels - 1)
                self.eye_stats[i]["widths"] = np.zeros(n_levels - 1)
                self.eye_stats[i]["areas"] = np.zeros(n_levels - 1)
                self.eye_stats[i]["mean_area"] = 0
                self.eye_stats[i]["min_area"] = 0

        warnings.resetwarnings()

    @staticmethod
    def approximate_levels(data, levels):
        amplitudes = data[1]
        grouping_data = amplitudes.reshape(-1, 1)

        kmeans, clusters = eye_diagram.kmeans_cluster(grouping_data, levels)

        centroids = np.zeros(levels)
        for i in range(levels):
            centroids[i] = eye_diagram.shorth(clusters[i])

        return np.sort(centroids)

    @staticmethod
    def kmeans_cluster(data, levels):
        working_data = data.reshape(-1, 1)
        # initial = np.linspace(np.min(working_data), np.max(working_data), levels).reshape(-1, 1)
        kmeans = kmeans2(working_data, levels, iter=100, minit="++")

        order = np.argsort(kmeans[0].squeeze())
        kmeans[0][:] = kmeans[0][order]
        order = np.argsort(order)
        kmeans[1][:] = order[kmeans[1]]

        clusters = [[] for _ in range(levels)]
        for i, elem in enumerate(data):
            clusters[kmeans[1][i]].append(elem.squeeze())
        clusters = [np.array(cluster) for cluster in clusters]

        # clusters = [clusters[i] for i in order]

        return kmeans, clusters

    @staticmethod
    def shorth(data):
        working_data = np.sort(data)
        n = len(working_data)
        h = n // 2 + 1
        min_diff = np.inf
        interval = np.zeros(2)
        for i in range(n - h):
            diff = working_data[i + h] - working_data[i]
            if diff < min_diff:
                min_diff = diff
                interval = [working_data[i], working_data[i + h]]
        return np.mean(interval)

    @staticmethod
    def calculate_time_bounds(data, level_centroids):
        n_levels = 2

        # prepare data
        selection_range = eye_diagram.calc_selection_range(level_centroids[1:3], 0.01)

        # times = np.arange(0, len(data), dtype=np.float32)
        times, amplitudes = data
        grouping_data = times[(amplitudes > selection_range[0]) & (amplitudes < selection_range[1])]
        grouping_data = grouping_data % 2
        grouping_data = grouping_data.reshape(-1, 1)

        kmeans, clusters = eye_diagram.kmeans_cluster(grouping_data, n_levels)

        # time_midpoint = (np.min(clusters[1]) + np.max(clusters[0]))/2

        # # check if time clusters are valid (upper bound > time_midpoint > lower bound)
        # # if not: raise ValueError
        # if not (np.max(clusters[0]) < time_midpoint < np.min(clusters[1])):
        #     raise ValueError

        return np.min(clusters[1]), np.max(clusters[0])

    @staticmethod
    def calc_selection_range(data, tolerance):
        middle = np.mean(data)
        tol = tolerance * np.abs(np.diff(data))
        return (middle - tol, middle + tol)

    @staticmethod
    def calculate_levels(data, level_centroids, time_bounds):
        selection_range = eye_diagram.calc_selection_range(time_bounds, 0.025)

        times, amplitudes = data
        indices = np.arange(0, len(times))
        filtered_time = indices[((times % 2) > selection_range[0]) & ((times % 2) < selection_range[1])]
        filtered_data = amplitudes[filtered_time]

        vertical_bounds = np.array([
            -np.inf,
            *[(level_centroids[i] + level_centroids[i + 1]) / 2 for i in range(len(level_centroids) - 1)],
            np.inf,
        ])

        central_level_means = np.zeros(len(level_centroids))
        amplitude_clusters = []
        for i in range(len(level_centroids)):
            amplitude_filtered_data = filtered_data[
                (filtered_data > vertical_bounds[i]) & (filtered_data < vertical_bounds[i + 1])
            ]
            amplitude_clusters.append(amplitude_filtered_data)
            central_level_means[i] = np.mean(amplitude_filtered_data)

        # # check if amplitude clusters are valid (upper bound > level_midpoint > lower bound)
        # # if not: raise ValueError
        # for j in range(len(amplitude_clusters)):
        #     level_midpoint = (central_level_means[j] + central_level_means[j+1]) / 2
        #     if not (np.max(amplitude_clusters[0]) < level_midpoint < np.min(amplitude_clusters[1])):
        #         raise ValueError

        return central_level_means, amplitude_clusters

    @staticmethod
    def calculate_eye_heights(amplitude_clusters):
        eye_heights = np.zeros(len(amplitude_clusters) - 1)
        for i in range(len(amplitude_clusters) - 1):
            eye_heights[i] = np.min(amplitude_clusters[i + 1]) - np.max(amplitude_clusters[i])
        return eye_heights

    @staticmethod
    def calculate_eye_widths(data, central_level_means):
        n_levels = len(central_level_means)

        widths = np.zeros(n_levels - 1)

        times, amplitudes = data
        clusters = []
        for i in range(n_levels - 1):
            selection_range = eye_diagram.calc_selection_range(
                [central_level_means[i], central_level_means[i + 1]], 0.01
            )
            grouping_data = times[(amplitudes > selection_range[0]) & (amplitudes < selection_range[1])]
            grouping_data = grouping_data % 2
            grouping_data = grouping_data.reshape(-1, 1)
            kmeans, cluster = eye_diagram.kmeans_cluster(grouping_data, 2)
            clusters.append(cluster)
            widths[i] = np.min(cluster[1]) - np.max(cluster[0])
            ...
        return widths, clusters

if __name__ == "__main__":
    length = int(2**14)
    # data = generate_sample_data(length, noise=1)
    # data1 = generate_sample_data(length, noise=0.01)
    # data2 = generate_sample_data(length, noise=0.01, skew=1.2)
    # data3 = generate_sample_data(length, noise=0.02)

    # data = np.stack([data, data1, data2, data3])

    data = generate_sample_data(length, noise=0.005)
    eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
    attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "areas", "mean_area", "min_area")
    for i, channel in enumerate(eye.eye_stats):
        print(f"Channel {i}")
        print_data = {attr: channel[attr] for attr in attrs}
        print(print_data)

    eye.plot()
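To close the loop with the dataset changes above, a sketch of feeding model output into the new class; `model_out`, `reference`, and `target_timestamps` are hypothetical arrays, and the import path is an assumption:

```python
import numpy as np

from util.eye_diagram import eye_diagram  # assumed import path

# hypothetical: sample indices -> symbol axis, complex fields -> optical power
t = target_timestamps / 128
channels = np.stack([
    np.vstack([t, np.abs(model_out) ** 2]),
    np.vstack([t, np.abs(reference) ** 2]),
])
eye = eye_diagram(channels, channel_names=["regenerated", "reference"], n_levels=4)
eye.plot(title="PAM4 eye diagram")
```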