Compare commits


6 Commits

12 changed files with 1210 additions and 350 deletions

.gitmodules

@@ -1,3 +1,4 @@
[submodule "pypho"]
path = pypho
url = git@gitlab.lrz.de:000000003B9B3E61/pypho.git
branch = main


@@ -33,7 +33,7 @@ flags = "FFTW_PATIENT"
nthreads = 32
[fiber]
length = 80000
length = 10000
gamma = 1.14
alpha = 0.2
D = 17
@@ -201,7 +201,7 @@ def initialize_fiber_and_data(config, input_data_override=None):
"jitter_seed", (int(time.time() * 1000)) % 2**32
)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", seed=config["signal"]["seed"]
py_glova, py_glova.nos, pattern="random", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
@@ -214,8 +214,8 @@ def initialize_fiber_and_data(config, input_data_override=None):
seed=config["signal"]["jitter_seed"],
)
symbols_x = symbolsrc(pattern="random", p1=config["signal"]["mod_order"])
symbols_y = symbolsrc(pattern="random", p1=config["signal"]["mod_order"])
symbols_x = symbolsrc()
symbols_y = symbolsrc()
symbols_x[:3] = 0
symbols_y[:3] = 0


@@ -258,12 +258,12 @@ class HyperTraining:
f"model_hidden_dim_{i}",
self.model_settings.n_hidden_nodes,
)
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
last_dim = hidden_dim
layers.append(getattr(util.complexNN, afunc)())
n_nodes += last_dim
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
model = nn.Sequential(*layers)


@@ -11,7 +11,7 @@ class GlobalSettings:
# data settings
@dataclass
class DataSettings:
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
config_path: tuple | list | str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: tuple = ("complex64", "float64")
symbols: tuple | float | int = 8
output_size: tuple | float | int = 64
@@ -39,7 +39,7 @@ class PytorchSettings:
summary_dir: str = ".runs"
write_every: int = 10
head_symbols: int = 40
eye_symbols: int = 400
eye_symbols: int = 1000
# model settings
@@ -52,13 +52,16 @@ class ModelSettings:
overrides: dict = field(default_factory=dict)
dropout_prob: float | None = None
model_layer_function: str | None = None
scale: bool = False
model_layer_kwargs: dict | None = None
model_layer_parametrizations: list= field(default_factory=list)
@dataclass
class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
learning_rate: tuple | float = (1e-5, 1e-1)
optimizer_kwargs: dict | None = None
# learning_rate: tuple | float = (1e-5, 1e-1)
scheduler: str | None = None
scheduler_kwargs: dict | None = None


@@ -1,8 +1,10 @@
import copy
from datetime import datetime
from pathlib import Path
import random
from typing import Literal
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.utils.parametrize
try:
@@ -50,6 +52,7 @@ class regenerator(nn.Module):
self,
*dims,
layer_function=util.complexNN.ONN,
layer_kwargs: dict | None = None,
layer_parametrizations: list[dict] = None,
# [
# {
@@ -64,6 +67,7 @@ class regenerator(nn.Module):
activation_function=util.complexNN.Pow,
dtype=torch.float64,
dropout_prob=0.01,
scale=False,
**kwargs,
):
super(regenerator, self).__init__()
@@ -74,39 +78,57 @@ class regenerator(nn.Module):
raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential()
if layer_kwargs is None:
layer_kwargs = {}
# self.powers = []
for i in range(self._n_hidden_layers + 1):
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
if scale:
self._layers.append(util.complexNN.Scale(dims[i]))
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
if i < self._n_hidden_layers:
if dropout_prob is not None:
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
self._layers.append(activation_function())
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
self._layers.append(util.complexNN.Scale(dims[-1]))
# add parametrizations
if layer_parametrizations is not None:
for layer in self._layers:
for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {})
if (
tensor_name is not None
and tensor_name in self._layers[-1]._parameters
and parametrization is not None
):
parametrization(self._layers[-1], tensor_name, **param_kwargs)
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
parametrization(layer, tensor_name, **param_kwargs)
def forward(self, input_x):
# def __call__(self, input_x, **kwargs):
# return self.forward(input_x, **kwargs)
def forward(self, input_x, trace_powers=False):
x = input_x
if trace_powers:
powers = [x.abs().square().sum()]
# check if tracing
if torch.jit.is_tracing():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
else:
# with torch.nn.utils.parametrize.cached():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
if trace_powers:
return x, powers
return x
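A minimal sketch of the new trace_powers hook (illustrative; the dims, dtype, and activation choice are assumptions, not part of this diff):

# assumes regenerator and util.complexNN are importable as above
model = regenerator(64, 32, 2,
                    activation_function=util.complexNN.EOActivation,
                    dtype=torch.complex64)
x = torch.ones(1, 64, dtype=torch.complex64)
y, powers = model(x, trace_powers=True)
# powers[0] is the input power, powers[k] the total power after layer k;
# successive ratios show where the network dissipates or adds energy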
def traverse_dict_update(target, source):
for k, v in source.items():
if isinstance(v, dict):
@@ -119,6 +141,7 @@ def traverse_dict_update(target, source):
except TypeError:
target.__dict__[k] = v
class Trainer:
def __init__(
self,
@@ -142,7 +165,7 @@ class Trainer:
OptimizerSettings,
PytorchSettings,
regenerator,
torch.nn.utils.parametrizations.orthogonal
torch.nn.utils.parametrizations.orthogonal,
])
if self.resume:
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
@@ -206,6 +229,11 @@ class Trainer:
}
def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = model_kwargs
if model_kwargs is None:
n_hidden_layers = self.model_settings.n_hidden_layers
@@ -228,6 +256,7 @@ class Trainer:
"activation_function": afunc,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
@@ -237,9 +266,12 @@ class Trainer:
# dims = self.model_kwargs.pop("dims")
self.model = regenerator(**self.model_kwargs)
if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
self.model = self.model.to(self.pytorch_settings.device)
if self.resume:
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False)
def get_sliced_data(self, override=None):
symbols = self.data_settings.symbols
@@ -253,11 +285,13 @@ class Trainer:
dtype = getattr(torch, self.data_settings.dtype)
num_symbols = None
config_path = self.data_settings.config_path
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
file_path=config_path,
symbols=symbols,
output_dim=data_size,
target_delay=in_out_delay,
@@ -330,10 +364,11 @@ class Trainer:
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
running_loss2 = 0.0
# running_loss2 = 0.0
running_loss = 0.0
self.model.train()
for batch_idx, (x, y) in enumerate(train_loader):
loader_len = len(train_loader)
for batch_idx, (x, y, _) in enumerate(train_loader):
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
@@ -344,24 +379,23 @@ class Trainer:
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss2 += loss_value
# running_loss2 += loss_value
running_loss += loss_value
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * len(train_loader) + batch_idx,
running_loss / (batch_idx + 1),
epoch * loader_len + batch_idx,
)
running_loss2 = 0.0
if enable_progress:
progress.stop()
return running_loss / len(train_loader)
return running_loss / (batch_idx + 1)
def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress:
@@ -384,7 +418,7 @@ class Trainer:
self.model.eval()
running_error = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
for batch_idx, (x, y, _) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
@@ -395,15 +429,17 @@ class Trainer:
running_error += error_value
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}")
running_error /= (batch_idx+1)
running_error /= len(valid_loader)
self.writer.add_scalar(
"eval loss",
running_error,
epoch,
)
if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1)
self.writer.add_figure(
"fiber response",
@@ -426,45 +462,74 @@ class Trainer:
),
epoch + 1,
)
self.writer_histograms(epoch + 1)
self.writer.add_figure(
"powers",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
epoch + 1,
)
self.write_parameters(epoch + 1)
self.writer.flush()
if enable_progress:
progress.stop()
return running_error
def run_model(self, model, loader):
def run_model(self, model, loader, trace_powers=False):
model.eval()
xs = []
ys = []
y_preds = []
fiber_out = []
fiber_in = []
regen = []
timestamps = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
for x, y, timestamp in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
if trace_powers:
y_pred, powers = model(x, trace_powers).cpu()
else:
y_pred = model(x, trace_powers).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
# timestamp = timestamp.view(-1, 1)
fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
fiber_in.append(y.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
fiber_out = torch.vstack(fiber_out).cpu()
fiber_in = torch.vstack(fiber_in).cpu()
regen = torch.vstack(regen).cpu()
timestamps = torch.concat(timestamps).cpu()
if trace_powers:
return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps
def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
for i, layer in enumerate(self.model._layers):
tag = f"layer {i}"
for attribute in attributes:
if hasattr(layer, attribute):
if hasattr(layer, "parametrizations"):
attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters)
else:
attribute_pool = set(layer._parameters)
for attribute in attribute_pool:
plot = (attributes is None) or (attribute in attributes)
if plot:
vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
if vals.ndim <= 1 and len(vals) == 1:
if np.iscomplexobj(vals):
@@ -483,14 +548,11 @@ class Trainer:
if self.writer is None:
self.setup_tb_writer()
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = None
self.define_model()
self.define_model(model_kwargs=model_kwargs)
print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
print(
f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})"
)
title_append, subtitle = self.build_title(0)
@@ -515,36 +577,55 @@ class Trainer:
),
0,
)
self.writer_histograms(0)
self.writer.add_figure(
"powers",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
0,
)
self.write_parameters(0)
self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path))
self.writer.flush()
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer
lr = self.optimizer_settings.learning_rate
# lr = self.optimizer_settings.learning_rate
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(
self.model.parameters(), **self.optimizer_settings.optimizer_kwargs
)
if self.optimizer_settings.scheduler is not None:
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
self.optimizer, **self.optimizer_settings.scheduler_kwargs
)
if self.resume:
try:
self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
except ValueError:
pass
self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
# if self.resume:
# try:
# self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
# except ValueError:
# pass
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1)
if not self.resume:
self.best = self.build_checkpoint_dict()
else:
self.best = self.checkpoint_dict
self.model.load_state_dict(self.best["model_state_dict"], strict=False)
try:
self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
except ValueError:
pass
self.best["loss"] = float("inf")
# self.model.load_state_dict(self.best["model_state_dict"], strict=False)
# try:
# self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
# except ValueError:
# pass
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
@@ -562,12 +643,8 @@ class Trainer:
enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
lr_old = self.scheduler.get_last_lr()
self.scheduler.step(loss)
lr_new = self.scheduler.get_last_lr()
if lr_old[0] != lr_new[0]:
self.writer.add_scalar("learning rate", lr_new[0], epoch)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
@@ -588,7 +665,28 @@ class Trainer:
self.writer.close()
return self.best
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
powers = [power / powers[0] for power in powers]
fig, ax = plt.subplots()
fig.set_figwidth(18)
fig.suptitle(
f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
ax.semilogy(powers, marker="o")
ax.set_xticks(range(len(layer_names)), layer_names, rotation=90)
ax.set_xlabel("Layer")
ax.set_ylabel("Normailzed Power")
ax.grid(which="major", axis="x")
ax.grid(which="major", axis="y")
ax.grid(which="minor", axis="y", linestyle=":")
fig.tight_layout()
if show:
plt.show()
return fig
def _plot_model_response_eye(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -603,16 +701,67 @@ class Trainer:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
y_bins = np.zeros((2 * len(signals), 1000))
eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
# signals = [signal.cpu().numpy() for signal in signals]
for i in range(len(signals) * 2):
eye_signal = signals[i // 2][:, i % 2] # x, y, x, y, ...
eye_signal = np.real(np.square(np.abs(eye_signal)))
data_min = np.min(eye_signal)
data_max = np.max(eye_signal)
y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
for j in range(len(timestamps)):
t = timestamps[j] / sps
val = eye_signal[j]
x = np.digitize(t % 2, x_bins) - 1
y = np.digitize(val, y_bins[i]) - 1
eye_data[i][y][x] += 1
cmap = LinearSegmentedColormap.from_list(
"eyemap",
[
(0, "white"),
(0.001, "dodgerblue"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "lime"),
(0.8, "gold"),
(1, "red"),
],
)
# ordering = np.argsort(timestamps)
# signals = [signal[ordering] for signal in signals]
# timestamps = timestamps[ordering]
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# xaxis = timestamps / sps
# xaxis = np.arange(2 * sps) / sps
for j, label in enumerate(labels):
x = eye_data[2 * j]
y = eye_data[2 * j + 1]
# x, y = signal.T
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
# for i in range(len(signal) // sps - 1):
# x, y = signal[i * sps : (i + 2) * sps].T
# axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
# axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
axs[0 + 2 * j].imshow(
x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
)
axs[1 + 2 * j].imshow(
y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
)
axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
ymin = np.min(y_bins[:, 0])
ymax = np.max(y_bins[:, -1])
ydiff = ymax - ymin
axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
axs[0 + 2 * j].set_title(label + " x")
axs[1 + 2 * j].set_title(label + " y")
axs[0 + 2 * j].set_xlabel("Symbol")
@@ -627,7 +776,9 @@ class Trainer:
plt.show()
return fig
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
def _plot_model_response_head(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
@@ -640,19 +791,29 @@ class Trainer:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
ordering = np.argsort(timestamps)
signals = [signal[ordering] for signal in signals]
timestamps = timestamps[ordering]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs):
ax: plt.Axes
for signal, label in zip(signals, labels):
if sps is not None:
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
xaxis = timestamps / sps
else:
xaxis = np.arange(len(signal))
xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.minorticks_on()
ax.tick_params(axis="y", which="minor", left=False, right=False)
ax.grid(which="major", axis="x")
ax.grid(which="minor", axis="x", linestyle=":")
ax.grid(which="major", axis="y")
ax.legend(loc="upper right")
fig.tight_layout()
if show:
@@ -664,22 +825,51 @@ class Trainer:
model=None,
title_append="",
subtitle="",
mode: Literal["eye", "head"] = "head",
mode: Literal["eye", "head", "powers"] = "head",
show=False,
):
if mode == "powers":
input_data = torch.ones(
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
).to(self.pytorch_settings.device)
model = model.to(self.pytorch_settings.device)
model.eval()
with torch.no_grad():
_, powers = model(input_data, trace_powers=True)
powers = [power.item() for power in powers]
layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
# remove dropout layers
mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
powers = [power for power, m in zip(powers, mask) if m]
fig = self._plot_model_response_powers(
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
)
return fig
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 100 * 128
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
config_path = random.choice(self.data_settings.config_path)
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
plot_loader, _ = self.get_sliced_data(
override={
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
}
)
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
@@ -687,6 +877,7 @@ class Trainer:
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
timestamps = timestamps.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
@@ -697,9 +888,10 @@ class Trainer:
fiber_in,
fiber_out,
regen,
timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
@@ -709,9 +901,10 @@ class Trainer:
fiber_in,
fiber_out,
regen,
timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)


@@ -1,3 +1,6 @@
import matplotlib
import numpy as np
import torch
from hypertraining.settings import (
GlobalSettings,
DataSettings,
@@ -7,16 +10,20 @@ from hypertraining.settings import (
)
from hypertraining.training import Trainer
import torch
# import torch
import json
import util
from rich import print as rprint
global_settings = GlobalSettings(
seed=42,
seed=0xC0FFEE,
)
data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
# config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
@@ -45,55 +52,83 @@ model_settings = ModelSettings(
output_dim=2,
n_hidden_layers=4,
overrides={
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 8,
"n_hidden_nodes_0": 4,
"n_hidden_nodes_1": 4,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 6,
"n_hidden_nodes_3": 4,
},
model_activation_func="PowScale",
# dropout_prob=0.01,
model_layer_function="ONN",
model_activation_func="EOActivation",
dropout_prob=0.01,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=True,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "bias",
"parametrization": util.complexNN.clamp,
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "bias",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
# {
# "tensor_name": "S",
# "parametrization": util.complexNN.clamp,
# },
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
learning_rate=0.05,
optimizer="AdamW",
optimizer_kwargs={
"lr": 0.05,
"amsgrad": True,
# "weight_decay": 1e-7,
},
# learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.9,
"factor": 0.75,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
)
def save_dict_to_file(dictionary, filename):
"""
Save the best dictionary to a JSON file.
@@ -103,28 +138,79 @@ def save_dict_to_file(dictionary, filename):
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, 'w') as f:
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None):
assert model is not None, "Model must be provided."
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
for length in lengths:
trainer = Trainer(
checkpoint_path=model,
settings_override={
"data_settings": {
"config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
"train_split": 1,
"shuffle": True,
}
},
)
trainer.define_model()
loader, _ = trainer.get_sliced_data()
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
channel_names = ["" for _ in range(2 * len(lengths))]
for li, length in enumerate(lengths):
data[2 * li, 0, :] = timestampss[length] / 128
data[2 * li, 1, :] = regens[length][:, 0].abs().square()
data[2 * li + 1, 0, :] = timestampss[length] / 128
data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
channel_names[2 * li] = f"regen x {length}"
channel_names[2 * li + 1] = f"regen y {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot()
matplotlib.use(backend)
if __name__ == "__main__":
# sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
trainer = Trainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
checkpoint_path='.models/20241128_084935_8885.tar',
settings_override={
"model_settings": {
# "model_activation_func": "PowScale",
"dropout_prob": 0,
}
},
reset_epoch=True,
# checkpoint_path=".models/best_20241202_143149.tar",
# 20241202_143149
)
best = trainer.train()
save_dict_to_file(best, ".models/best_results.json")
...
trainer.train()


@@ -0,0 +1,88 @@
# move into dir single-core-regen before running
from util.datasets import FiberRegenerationDataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np
# def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
# if no_symbols is None:
# no_symbols = len(dataset)
# _, axs = plt.subplots(2,2, sharex=True, sharey=True)
# xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
# roll = dataset.samples_per_symbol//2 if offset else 0
# for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]:
# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
# axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0')
# axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0')
# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0')
# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0')
# if show:
# plt.show()
# # def plt_dataloader(dataloader, show=True):
# # _, axs = plt.subplots(2,2, sharex=True, sharey=True)
# # E_outs, E_ins = next(iter(dataloader))
# # for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
# # xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
# # E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
# # axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
# # axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
# # axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
# # axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
# # if show:
# # plt.show()
if __name__ == "__main__":
dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
loader = DataLoader(dataset, batch_size=10, shuffle=True)
x = []
y_fiber_in = []
y_fiber_out = []
for i, batch in enumerate(loader):
# if i > 128:
# break
fiber_in, fiber_out, timestamp = batch
fiber_out = fiber_out.reshape(fiber_out.shape[0], -1, 2)
fiber_out = fiber_out[:,fiber_out.shape[1]//2, :]
# input_data = input_data.reshape(-1,2)
# target = target.reshape(-1,2).squeeze()
# timestamp = timestamp.reshape(-1,1).squeeze()
x.append(timestamp.detach().numpy())
y_fiber_in.append(fiber_in.abs().square().detach().numpy())
y_fiber_out.append(fiber_out.abs().square().detach().numpy())
x = np.concat(x)
y_fiber_in = np.concat(y_fiber_in)
y_fiber_out = np.concat(y_fiber_out)
# order = np.argsort(x)
# x = x[order]
# y = y[order]
fig, axs = plt.subplots(2,2, sharex=True, sharey=True)
axs[0,0].scatter((x/dataset.samples_per_symbol)%2, y_fiber_in[:,0], s=1, alpha=0.1)
axs[1,0].scatter((x/dataset.samples_per_symbol)%2, y_fiber_in[:,1], s=1, alpha=0.1)
axs[0,1].scatter((x/dataset.samples_per_symbol)%2, y_fiber_out[:,0], s=1, alpha=0.1)
axs[1,1].scatter((x/dataset.samples_per_symbol)%2, y_fiber_out[:,1], s=1, alpha=0.1)
plt.show()
# eye_dataset(dataset, 1000, offset=True, show=False)
# train_loader = DataLoader(dataset, batch_size=10, shuffle=False)
# plt_dataloader(train_loader, show=False)
# plt.show()


@@ -1,51 +0,0 @@
# move into dir single-core-regen before running
from util.dataset import SlicedDataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np
def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
if no_symbols is None:
no_symbols = len(dataset)
_, axs = plt.subplots(2,2, sharex=True, sharey=True)
xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
roll = dataset.samples_per_symbol//2 if offset else 0
for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]:
E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0')
axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0')
axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0')
axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0')
if show:
plt.show()
# def plt_dataloader(dataloader, show=True):
# _, axs = plt.subplots(2,2, sharex=True, sharey=True)
# E_outs, E_ins = next(iter(dataloader))
# for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
# xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
# axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
# axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
# if show:
# plt.show()
if __name__ == "__main__":
dataset = SlicedDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=1, drop_first=100)
print(dataset[0][0].shape)
eye_dataset(dataset, 1000, offset=True, show=False)
train_loader = DataLoader(dataset, batch_size=10, shuffle=False)
# plt_dataloader(train_loader, show=False)
plt.show()


@@ -17,3 +17,5 @@ from . import complexNN # noqa: F401
# from .complexNN import complex_sse_loss # noqa: F401
from . import misc # noqa: F401
from . import eye_diagram # noqa: F401


@@ -4,23 +4,36 @@ import torch.nn.functional as F
# from torchlambertw.special import lambertw
def complex_mse_loss(input, target, power=False, reduction="mean"):
def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"):
"""
Compute the mean squared error between two complex tensors.
If power is True, the loss is computed between the powers |input|^2 and |target|^2.
If normalize is True, both tensors are rescaled to a common range before comparison
and a penalty on the mismatch of their original scales is added to the loss.
"""
reduce = getattr(torch, reduction)
power_penalty = 0
if power:
input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
if normalize:
power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2
power_penalty += (input.min() - target.min()) ** 2
input = input - input.min()
input = input / input.max()
target = target - target.min()
target = target / target.max()
else:
if normalize:
power_penalty = (input.abs().max() - target.abs().max()) ** 2
input = input / input.abs().max()
target = target / target.abs().max()
if input.is_complex() and target.is_complex():
return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty
elif input.is_complex() or target.is_complex():
raise ValueError("Input and target must have the same type (real or complex)")
else:
return F.mse_loss(input, target, reduction=reduction)
return F.mse_loss(input, target, reduction=reduction) + power_penalty
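A short usage sketch for the extended loss (values are illustrative; assumes complex_mse_loss as defined above):

pred = torch.randn(4, 8, dtype=torch.complex64)
target = torch.randn(4, 8, dtype=torch.complex64)
loss_plain = complex_mse_loss(pred, target)  # plain complex MSE
# compare squared magnitudes, rescaled, plus the range/offset penalty
loss_power = complex_mse_loss(pred, target, power=True, normalize=True)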
def complex_sse_loss(input, target):
@@ -53,14 +66,10 @@ class UnitaryLayer(nn.Module):
return f"UnitaryLayer({self.in_features}, {self.out_features})"
class _Unitary(nn.Module):
def forward(self, X: torch.Tensor):
if X.ndim < 2:
raise ValueError(
"Only tensors with 2 or more dimensions are supported. "
f"Got a tensor of shape {X.shape}"
)
raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}")
n, k = X.size(-2), X.size(-1)
transpose = n < k
if transpose:
@@ -80,6 +89,7 @@ class _Unitary(nn.Module):
# X.copy_(q)
return q
def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
@@ -95,6 +105,7 @@ def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _SpecialUnitary(nn.Module):
def __init__(self):
super().__init__()
@@ -108,6 +119,7 @@ class _SpecialUnitary(nn.Module):
return q
def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
@@ -123,11 +135,13 @@ def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _Clamp(nn.Module):
def __init__(self, min, max):
super(_Clamp, self).__init__()
self.min = min
self.max = max
def forward(self, x):
if x.is_complex():
# clamp magnitude, ignore phase
@@ -145,43 +159,29 @@ def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
return module
class ONNMiller(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONNMiller, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
class _EnergyConserving(nn.Module):
def __init__(self):
super(_EnergyConserving, self).__init__()
self.dim = max(input_dim, output_dim)
def forward(self, X: torch.Tensor):
if X.ndim == 2:
X = X.unsqueeze(0)
spectral_norm = torch.linalg.svdvals(X)[:, 0]
return (X / spectral_norm).squeeze()
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module:
param = getattr(module, name, None)
if not isinstance(param, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1)
self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
# V is actually V.H, but
if not (2 <= param.ndim <= 3):
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.")
unit = _EnergyConserving()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
def forward(self, x_in):
x = x_in
x = self.pad(x)
x = x @ self.U
x = x * (self.S.squeeze() / self.MZI_scale)
x = x @ self.V
x = self.crop(x)
return x
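A hypothetical registration of the new parametrization (sketch; the layer and its size are assumptions):

lin = nn.Linear(8, 8, bias=False)
energy_conserving(lin, "weight")
# the exposed weight is now divided by its largest singular value, so its
# spectral norm is 1 and ||W @ x|| <= ||x||: the layer cannot amplify power
print(torch.linalg.svdvals(lin.weight)[0])  # ~1.0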
class ONN(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
@@ -202,18 +202,21 @@ class ONN(nn.Module):
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
self.crop = lambda x: x[
:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)
]
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {self.dim}"
self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
# self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5)
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
# def get_M(self):
# return self.U @ self.sigma @ self.V
@@ -221,37 +224,50 @@ class ONN(nn.Module):
return self.crop(self.pad(x) @ self.weight)
class SemiUnitaryLayer(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None):
super(SemiUnitaryLayer, self).__init__()
class ONNRect(nn.Module):
def __init__(self, input_dim, output_dim, square=False, dtype=None):
super(ONNRect, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
# Create a larger square matrix for QR decomposition
self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
self.reset_parameters()
if square:
dim = max(input_dim, output_dim)
self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))
def reset_parameters(self):
# Ensure the weights are unitary by QR decomposition
q, _ = torch.linalg.qr(self.weight)
# A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
# truncate the matrix to the desired size
if self.input_dim > self.output_dim:
self.weight.data = q[: self.input_dim, : self.output_dim]
# zero pad input to internal size if smaller
if self.input_dim < dim:
self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2))
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}"
else:
self.weight.data = q[: self.output_dim, : self.input_dim].t()
...
self.pad = lambda x: x
self.pad.__doc__ = f"Input size equals internal size {dim}"
# crop output to desired size
if self.output_dim < dim:
self.crop = lambda x: x[
:, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2)
]
self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}"
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {dim}"
else:
self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype))
self.pad = lambda x: x
self.pad.__doc__ = "No padding"
self.crop = lambda x: x
self.crop.__doc__ = "No cropping"
def forward(self, x):
with torch.no_grad():
scale = torch.clamp(self.scale, 0.0, 1.0)
out = torch.matmul(x, scale * self.weight)
x = self.pad(x)
out = self.crop((self.weight @ x.mT).mT)
return out
def __repr__(self):
return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
# def __repr__(self):
# return f"ONNRect({self.input_dim}, {self.output_dim})"
# class SaturableAbsorberLambertW(nn.Module):
@@ -336,6 +352,19 @@ class DropoutComplex(nn.Module):
return self.dropout(x)
class Scale(nn.Module):
def __init__(self, size):
super(Scale, self).__init__()
self.size = size
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x):
return x * self.scale
def __repr__(self):
return f"Scale({self.size})"
class Identity(nn.Module):
"""
implements the "activation" function
@@ -348,6 +377,7 @@ class Identity(nn.Module):
def forward(self, x):
return x
class PowRot(nn.Module):
def __init__(self, bias=False):
super(PowRot, self).__init__()
@@ -363,11 +393,71 @@ class PowRot(nn.Module):
else:
return x
class MZISingle(nn.Module):
def __init__(self, bias, size, func=None):
super(MZISingle, self).__init__()
self.omega = nn.Parameter(torch.randn(size))
self.phi = nn.Parameter(torch.randn(size))
self.func = func or (lambda x: x.abs().square()) # default to |z|^2
def forward(self, x: torch.Tensor):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
class EOActivation(nn.Module):
def __init__(self, bias, size=None):
# 10.1109/SiPhotonics60897.2024.10543376
super(EOActivation, self).__init__()
if size is None:
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.ones(size))
self.V_bias = nn.Parameter(torch.ones(size))
self.gain = nn.Parameter(torch.ones(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.reset_weights()
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5
if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size)
if "gain" in self._parameters:
self.gain.data = torch.ones(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
intermediate = g_phi * x.abs().square() + phi_b
return (
1j
* torch.sqrt(1 - self.alpha)
* torch.exp(-0.5j * (intermediate + self.phase_bias))
* torch.cos(0.5 * intermediate)
* x
)
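Standalone sketch of the new activation (size and input are assumptions; the regenerator above instantiates it as activation_function(bias=True, size=dims[i + 1])):

act = EOActivation(bias=True, size=4)
z = torch.randn(3, 4, dtype=torch.complex64)
out = act(z)  # same shape, complex; the nonlinearity is driven by |z|^2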
class Pow(nn.Module):
"""
implements the activation function
M(z) = ||z||^2 + b
"""
def __init__(self, bias=False):
super(Pow, self).__init__()
if bias:
@@ -375,7 +465,6 @@ class Pow(nn.Module):
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().square().add(self.bias).to(dtype=x.dtype)
@@ -408,6 +497,7 @@ class MagScale(nn.Module):
def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)
class PowScale(nn.Module):
def __init__(self, bias=False):
super(PowScale, self).__init__()
@@ -486,10 +576,10 @@ __all__ = [
complex_mse_loss,
UnitaryLayer,
unitary,
energy_conserving,
clamp,
ONN,
ONNMiller,
SemiUnitaryLayer,
ONNRect,
DropoutComplex,
Identity,
Pow,
@@ -498,6 +588,8 @@ __all__ = [
ModReLU,
CReLU,
ZReLU,
MZISingle,
EOActivation,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,


@@ -40,7 +40,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
@@ -53,6 +54,8 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
config["glova"]["nos"] = str(symbols)
data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1)
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
@@ -100,7 +103,7 @@ class FiberRegenerationDataset(Dataset):
def __init__(
self,
file_path: str | Path,
file_path: tuple | list | str | Path,
symbols: int | float,
*,
output_dim: int = None,
@@ -130,12 +133,12 @@ class FiberRegenerationDataset(Dataset):
"""
# check types
assert isinstance(file_path, str), "file_path must be a string"
assert isinstance(file_path, (str, Path, tuple, list)), "file_path must be a string, Path, tuple, or list"
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
assert isinstance(drop_first, int), "drop_first must be an integer"
# assert isinstance(drop_first, int), "drop_first must be an integer"
# check values
assert symbols > 0, "symbols must be positive"
@@ -150,20 +153,38 @@ class FiberRegenerationDataset(Dataset):
dtype=np.complex128,
)
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
timestamps = torch.arange(12800)
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw, self.config = load_data(
data_raw = None
self.config = None
files = []
for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
data, config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.pop("num_symbols", None),
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
if data_raw is None:
data_raw = data
else:
data_raw = torch.cat([data_raw, data], dim=0)
if self.config is None:
self.config = config
else:
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
self.device = data_raw.device
@@ -190,10 +211,10 @@ class FiberRegenerationDataset(Dataset):
# data_raw = torch.tensor(data_raw, dtype=dtype)
# data layout
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0, timestamp0],
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1, timestamp1],
# ...
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN, timestampN] ]
data_raw = data_raw.transpose(0, 1)
@@ -201,16 +222,18 @@ class FiberRegenerationDataset(Dataset):
# [ E_in_x[0:N],
# E_in_y[0:N],
# E_out_x[0:N],
# E_out_y[0:N] ]
# E_out_y[0:N],
# timestamps[0:N] ]
# shift x data by xy_delay_samples relative to the y data (example value: 3)
# [ E_in_x [0:N], [ E_in_x [ 0:N ], [ E_in_x [3:N ],
# E_in_y [0:N], -> E_in_y [-3:N-3], -> E_in_y [0:N-3],
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[3:N ],
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3],
# timestamps[0:N] ] timestamps[ 0:N ] ] timestamps[3:N ] ]
if self.xy_delay_samples != 0:
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples, 0], dim=1)
if self.xy_delay_samples > 0:
data_raw = data_raw[:, self.xy_delay_samples :]
elif self.xy_delay_samples < 0:
@@ -221,12 +244,13 @@ class FiberRegenerationDataset(Dataset):
# [ E_in_x [0:N], [ E_in_x [-5:N-5], [ E_in_x [0:N-5],
# E_in_y [0:N], -> E_in_y [-5:N-5], -> E_in_y [0:N-5],
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[5:N ],
# E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ] ]
# E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ],
# timestamps[0:N] ] timestamps[ 0:N ] ] timestamps[5:N ]
if self.target_delay_samples != 0:
data_raw = roll_along(
data_raw,
[self.target_delay_samples, self.target_delay_samples, 0, 0],
[self.target_delay_samples, self.target_delay_samples, 0, 0, 0],
dim=1,
)
if self.target_delay_samples > 0:
@@ -234,21 +258,25 @@ class FiberRegenerationDataset(Dataset):
elif self.target_delay_samples < 0:
data_raw = data_raw[:, : self.target_delay_samples]
timestamps = data_raw[-1, :]
data_raw = data_raw[:-1, :]
data_raw = data_raw.view(2, 2, -1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout
# [ [E_in_x, E_in_y],
# [E_out_x, E_out_y] ]
# [ [E_in_x, E_in_y, timestamps],
# [E_out_x, E_out_y, timestamps] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0)
# -> [no_slices, 2, 2, samples_per_slice]
# -> [no_slices, 2, 3, samples_per_slice]
# data layout
# [
# [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
# [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
# [ [E_in_x[0:N+0], E_in_y[0:N+0], timestamps[0:N+0]], [ E_out_x[0:N+0], E_out_y[0:N+0], timestamps[0:N+0] ] ],
# [ [E_in_x[1:N+1], E_in_y[1:N+1], timestamps[1:N+1]], [ E_out_x[1:N+1], E_out_y[1:N+1], timestamps[1:N+1] ] ],
# ...
# ] -> [no_slices, 2, 2, samples_per_slice]
# ] -> [no_slices, 2, 3, samples_per_slice]
...
@@ -259,24 +287,24 @@ class FiberRegenerationDataset(Dataset):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
else:
data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
data_slice = self.data[idx].squeeze()
# reduce by taking self.output_dim equally spaced samples
data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
data = data.view(data.shape[0], self.output_dim, -1)
data = data[:, :, 0]
data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
target = data_slice[0, :, self.output_dim//2, 0]
data = data_slice[1, :, :, 0]
# data_timestamps = data[-1,:].real
data = data[:-1, :]
target_timestamp = target[-1].real
target = target[:-1]
# the target corresponds to the middle of the data window, since the output sample is influenced by the samples before and after it
target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
target = target.view(target.shape[0], self.output_dim, -1)
target = target[:, 0, target.shape[2] // 2]
data = data.transpose(0, 1).flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze()
target = target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze()
# data layout:
# [sample_x0, sample_y0, sample_x1, sample_y1, ...]
# target layout:
# [sample_x0, sample_y0]
return data, target
return data, target, target_timestamp
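A hedged sketch of consuming the new three-tuple (the config path and sizes are placeholders, not from this diff):

# ds = FiberRegenerationDataset("data/<config>.ini", symbols=13, output_dim=64)
# data, target, t = ds[0]
# data: interleaved fiber-output samples [x0, y0, x1, y1, ...], shape (2 * output_dim,)
# target: the fiber-input [x, y] pair at the window centre, shape (2,)
# t: timestamp of that target sample, carried through for the eye/head plots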


@@ -0,0 +1,418 @@
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from scipy.cluster.vq import kmeans2
import warnings
from rich.traceback import install
from rich import pretty
from rich import print
install()
pretty.install()
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
data = create_symbol_sequence(n_symbols, skew=skew)
signal = generate_signal(data, sps)
signal = normalization_with_noise(signal, noise)
xaxis = np.arange(0, len(signal)) / sps
return np.vstack([xaxis, signal])
def create_symbol_sequence(n_symbols, skew=1):
np.random.seed(42)
data = np.random.randint(0, 4, n_symbols) / 4
data = np.pow(data, skew)
return tuple(data)
def generate_signal(data, sps):
working_data = np.diff(data, prepend=data[0])
data_padded = np.zeros(len(data) * sps)
data_padded[::sps] = working_data
data_padded = np.pad(data_padded, (0, sps // 2), mode="constant")
wavelet = generate_wavelet(sps, oversample=3)
signal = np.convolve(data_padded, wavelet)
signal = np.cumsum(signal)
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
return signal
def normalization_with_noise(signal, noise=0):
if noise > 0:
awgn = np.random.normal(0, noise * (np.max(signal) - np.min(signal)), len(signal))
signal += awgn
# min-max normalization
signal = signal - np.min(signal)
signal = signal / np.max(signal)
return signal
def generate_wavelet(sps, oversample=3):
sample_points = np.linspace(
-oversample * sps,
oversample * sps,
2 * oversample * sps,
endpoint=True,
)
sigma = 0.33 / (1 * np.sqrt(2 * np.log(2))) * sps
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
return pulse
class eye_diagram:
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4):
# data has shape [channels, 2, samples]
# each sample has a timestamp and a value
if data.ndim == 2:
data = data[np.newaxis, :, :]
self.channel_names = channel_names
self.raw_data = data
self.channels = data.shape[0]
self.n_levels = n_levels
self.eye_stats = [{"success": False} for _ in range(self.channels)]
self.horizontal_bins = horizontal_bins
self.vertical_bins = vertical_bins
self.eye_built = False
self.analyse(self.n_levels)
def generate_eye_data(self):
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.channels, self.vertical_bins))
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
for i in range(self.channels):
data_min = np.min(self.raw_data[i, 1, :])
data_max = np.max(self.raw_data[i, 1, :])
self.y_bins[i] = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
t_vals = self.raw_data[i, 0, :] % 2
val_vals = self.raw_data[i, 1, :]
x_indices = np.digitize(t_vals, self.x_bins) - 1
y_indices = np.digitize(val_vals, self.y_bins[i]) - 1
np.add.at(self.eye_data[i], (y_indices, x_indices), 1)
self.eye_built = True
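The binning above is a vectorized 2-D histogram; a minimal standalone equivalent (toy values, assuming the numpy import at the top of this file):

t = np.array([0.10, 0.12, 1.90])  # timestamps, already folded mod 2
v = np.array([0.20, 0.25, 0.80])  # sample values
x_edges = np.linspace(0, 2, 4, endpoint=False)
y_edges = np.linspace(0, 1, 5, endpoint=False)
H = np.zeros((5, 4))
np.add.at(H, (np.digitize(v, y_edges) - 1, np.digitize(t, x_edges) - 1), 1)
# H[y, x] counts samples per (amplitude, time) cell, exactly like eye_data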
def plot(self, title="Eye Diagram", stats=True, show=True):
if not self.eye_built:
self.generate_eye_data()
cmap = LinearSegmentedColormap.from_list(
"eyemap",
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
)
if self.channels % 2 == 0:
rows = 2
cols = self.channels // 2
else:
cols = int(np.ceil(np.sqrt(self.channels)))
rows = int(np.ceil(self.channels / cols))
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
fig.suptitle(title)
ax = np.atleast_1d(ax).transpose().flatten()
for i in range(self.channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
ax[i].set_xlabel("Symbol")
ax[i].set_ylabel("Amplitude")
ax[i].grid()
ax[i].imshow(
self.eye_data[i],
origin="lower",
aspect="auto",
cmap=cmap,
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
)
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
ymin = np.min(self.y_bins[:, 0])
ymax = np.max(self.y_bins[:, -1])
yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
if stats and self.eye_stats[i]["success"]:
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
ax[i].set_yticks(self.eye_stats[i]["levels"])
# add arrows for amplitudes
for j in range(len(self.eye_stats[i]["amplitudes"])):
ax[i].annotate(
"",
xy=(0.05, self.eye_stats[i]["levels"][j]),
xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
arrowprops=dict(arrowstyle="<->", facecolor="black"),
)
ax[i].annotate(
f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
)
# add arrows for eye heights
for j in range(len(self.eye_stats[i]["heights"])):
try:
bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
ax[i].annotate(
"",
xy=(self.eye_stats[i]["time_midpoint"], bot),
xytext=(self.eye_stats[i]["time_midpoint"], top),
arrowprops=dict(arrowstyle="<->", facecolor="black"),
)
ax[i].annotate(
f"{self.eye_stats[i]['heights'][j]:.2e}",
xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
)
except (ValueError, IndexError):
pass
# add arrows for eye widths
for j in range(len(self.eye_stats[i]["widths"])):
try:
left = np.max(self.eye_stats[i]["time_clusters"][j][0])
right = np.min(self.eye_stats[i]["time_clusters"][j][1])
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate(
"",
xy=(left, vertical),
xytext=(right, vertical),
arrowprops=dict(arrowstyle="<->", facecolor="black"),
)
ax[i].annotate(
f"{self.eye_stats[i]['widths'][j]:.2e}",
xy=((left + right) / 2 - 0.15, vertical + 0.01),
)
except (ValueError, IndexError):
pass
# add area
for j in range(len(self.eye_stats[i]["areas"])):
horizontal = self.eye_stats[i]["time_midpoint"]
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate(
f"{self.eye_stats[i]['areas'][j]:.2e}",
xy=(horizontal + 0.035, vertical - 0.07),
)
# add min_area above the plot
ax[i].annotate(
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
xy=(0.05, ymax + 0.05 * yspan),
# xycoords="axes fraction",
ha="left",
va="center",
)
fig.tight_layout()
if show:
plt.show()
return fig
def analyse(self, n_levels=4):
warnings.filterwarnings("error")
for i in range(self.channels):
self.eye_stats[i]["channel"] = str(i+1) if self.channel_names is None else self.channel_names[i]
try:
approx_levels = eye_diagram.approximate_levels(self.raw_data[i], n_levels)
time_bounds = eye_diagram.calculate_time_bounds(self.raw_data[i], approx_levels)
self.eye_stats[i]["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
self.eye_stats[i]["levels"], self.eye_stats[i]["amplitude_clusters"] = eye_diagram.calculate_levels(
self.raw_data[i], approx_levels, time_bounds
)
self.eye_stats[i]["amplitudes"] = np.diff(self.eye_stats[i]["levels"])
self.eye_stats[i]["heights"] = eye_diagram.calculate_eye_heights(
self.eye_stats[i]["amplitude_clusters"]
)
self.eye_stats[i]["widths"], self.eye_stats[i]["time_clusters"] = eye_diagram.calculate_eye_widths(
self.raw_data[i], self.eye_stats[i]["levels"]
)
# # check if time clusters are valid (upper bound > time_midpoint > lower bound)
# # if not: raise ValueError
# for j in range(len(self.eye_stats[i]['time_clusters'])):
# if not (np.max(self.eye_stats[i]['time_clusters'][j][0]) < self.eye_stats[i]["time_midpoint"] < np.min(self.eye_stats[i]['time_clusters'][j][1])):
# raise ValueError
self.eye_stats[i]["areas"] = self.eye_stats[i]["heights"] * self.eye_stats[i]["widths"]
self.eye_stats[i]["mean_area"] = np.mean(self.eye_stats[i]["areas"])
self.eye_stats[i]["min_area"] = np.min(self.eye_stats[i]["areas"])
self.eye_stats[i]["success"] = True
except (RuntimeWarning, UserWarning, ValueError):
self.eye_stats[i]["success"] = False
self.eye_stats[i]["time_midpoint"] = 0
self.eye_stats[i]["levels"] = np.zeros(n_levels)
self.eye_stats[i]["amplitude_clusters"] = []
self.eye_stats[i]["amplitudes"] = np.zeros(n_levels - 1)
self.eye_stats[i]["heights"] = np.zeros(n_levels - 1)
self.eye_stats[i]["widths"] = np.zeros(n_levels - 1)
self.eye_stats[i]["areas"] = np.zeros(n_levels - 1)
self.eye_stats[i]["mean_area"] = 0
self.eye_stats[i]["min_area"] = 0
warnings.resetwarnings()
@staticmethod
def approximate_levels(data, levels):
amplitudes = data[1]
grouping_data = amplitudes.reshape(-1, 1)
kmeans, clusters = eye_diagram.kmeans_cluster(grouping_data, levels)
centroids = np.zeros(levels)
for i in range(levels):
centroids[i] = eye_diagram.shorth(clusters[i])
return np.sort(centroids)
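# k-means gives a rough partition of the amplitudes into `levels` groups;
# taking the shorth of each cluster (rather than its mean) keeps the level
# estimate robust against transition samples that leak into a cluster.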
@staticmethod
def kmeans_cluster(data, levels):
working_data = data.reshape(-1, 1)
# initial = np.linspace(np.min(working_data), np.max(working_data), levels).reshape(-1, 1)
kmeans = kmeans2(working_data, levels, iter=100, minit="++")
order = np.argsort(kmeans[0].squeeze())
kmeans[0][:] = kmeans[0][order]
order = np.argsort(order)
kmeans[1][:] = order[kmeans[1]]
clusters = [[] for _ in range(levels)]
for i, elem in enumerate(data):
clusters[kmeans[1][i]].append(elem.squeeze())
clusters = [np.array(cluster) for cluster in clusters]
# clusters = [clusters[i] for i in order]
return kmeans, clusters
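# kmeans2 labels clusters in arbitrary order, so the centroids are sorted
# ascending and the label array is remapped via the inverse permutation;
# callers can then rely on cluster 0 being the lowest group, and so on.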
@staticmethod
def shorth(data):
working_data = np.sort(data)
n = len(working_data)
h = n // 2 + 1
min_diff = np.inf
interval = np.zeros(2)
for i in range(n - h):
diff = working_data[i + h] - working_data[i]
if diff < min_diff:
min_diff = diff
interval = [working_data[i], working_data[i + h]]
return np.mean(interval)
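# The shorth is the midpoint of the shortest interval containing half the
# samples -- a robust mode estimate. E.g. for [0, .01, .02, .03, .04, 5, 6]
# the shortest half-interval is [0, 0.04], giving ~0.02 despite the outliers.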
@staticmethod
def calculate_time_bounds(data, level_centroids):
n_levels = 2
# prepare data
selection_range = eye_diagram.calc_selection_range(level_centroids[1:3], 0.01)
# times = np.arange(0, len(data), dtype=np.float32)
times, amplitudes = data
grouping_data = times[(amplitudes > selection_range[0]) & (amplitudes < selection_range[1])]
grouping_data = grouping_data % 2
grouping_data = grouping_data.reshape(-1, 1)
kmeans, clusters = eye_diagram.kmeans_cluster(grouping_data, n_levels)
# time_midpoint = (np.min(clusters[1]) + np.max(clusters[0]))/2
# # check if time clusters are valid (upper bound > time_midpoint > lower bound)
# # if not: raise ValueError
# if not (np.max(clusters[0]) < time_midpoint < np.min(clusters[1])):
# raise ValueError
return np.min(clusters[1]), np.max(clusters[0])
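# Idea: samples sitting on the crossing amplitude between the two middle
# levels (the [1:3] slice assumes a PAM4-style, 4-level signal) belong to
# symbol transitions; folding their times mod 2 and 2-means clustering them
# yields the two crossing instants, and the eye centre is taken midway
# between the inner edges of those two clusters.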
@staticmethod
def calc_selection_range(data, tolerance):
middle = np.mean(data)
tol = tolerance * np.abs(np.diff(data))
return (middle - tol, middle + tol)
@staticmethod
def calculate_levels(data, level_centroids, time_bounds):
selection_range = eye_diagram.calc_selection_range(time_bounds, 0.025)
times, amplitudes = data
indices = np.arange(0, len(times))
filtered_time = indices[((times % 2) > selection_range[0]) & ((times % 2) < selection_range[1])]
filtered_data = amplitudes[filtered_time]
vertical_bounds = np.array([
-np.inf,
*[(level_centroids[i] + level_centroids[i + 1]) / 2 for i in range(len(level_centroids) - 1)],
np.inf,
])
central_level_means = np.zeros(len(level_centroids))
amplitude_clusters = []
for i in range(len(level_centroids)):
amplitude_filtered_data = filtered_data[
(filtered_data > vertical_bounds[i]) & (filtered_data < vertical_bounds[i + 1])
]
amplitude_clusters.append(amplitude_filtered_data)
central_level_means[i] = np.mean(amplitude_filtered_data)
# # check if amplitude clusters are valid (upper bound > level_midpoint > lower bound)
# # if not: raise ValueError
# for j in range(len(amplitude_clusters)):
# level_midpoint = (central_level_means[j] + central_level_means[j+1]) / 2
# if not (np.max(amplitude_clusters[0]) < level_midpoint < np.min(amplitude_clusters[1])):
# raise ValueError
return central_level_means, amplitude_clusters
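# The refined levels are plain means of the amplitudes observed in a narrow
# time window around the eye centre, binned by the midpoints between the
# coarse k-means levels from approximate_levels.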
@staticmethod
def calculate_eye_heights(amplitude_clusters):
eye_heights = np.zeros(len(amplitude_clusters) - 1)
for i in range(len(amplitude_clusters) - 1):
eye_heights[i] = np.min(amplitude_clusters[i + 1]) - np.max(amplitude_clusters[i])
return eye_heights
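# Eye height = vertical gap between the lowest sample of the upper cluster
# and the highest sample of the lower cluster at the eye centre; a negative
# value means the clusters overlap, i.e. that eye is closed.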
@staticmethod
def calculate_eye_widths(data, central_level_means):
n_levels = len(central_level_means)
widths = np.zeros(n_levels - 1)
times, amplitudes = data
clusters = []
for i in range(n_levels - 1):
selection_range = eye_diagram.calc_selection_range(
[central_level_means[i], central_level_means[i + 1]], 0.01
)
grouping_data = times[(amplitudes > selection_range[0]) & (amplitudes < selection_range[1])]
grouping_data = grouping_data % 2
grouping_data = grouping_data.reshape(-1, 1)
kmeans, cluster = eye_diagram.kmeans_cluster(grouping_data, 2)
clusters.append(cluster)
widths[i] = np.min(cluster[1]) - np.max(cluster[0])
...
return widths, clusters
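# Same trick as calculate_time_bounds, applied once per eye opening: samples
# near the midpoint between two adjacent levels are crossing samples, and the
# width is the gap between the inner edges of their two folded-time clusters.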
if __name__ == "__main__":
length = int(2**14)
# data = generate_sample_data(length, noise=1)
# data1 = generate_sample_data(length, noise=0.01)
# data2 = generate_sample_data(length, noise=0.01, skew=1.2)
# data3 = generate_sample_data(length, noise=0.02)
# data = np.stack([data, data1, data2, data3])
data = generate_sample_data(length, noise=0.005)
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
for i, channel in enumerate(eye.eye_stats):
print(f"Channel {i}")
print_data = {attr: channel[attr] for attr in attrs}
print(print_data)
eye.plot()