diff --git a/src/single-core-regen/hypertraining/hypertraining.py b/src/single-core-regen/hypertraining/hypertraining.py
index bed9bcc..758331e 100644
--- a/src/single-core-regen/hypertraining/hypertraining.py
+++ b/src/single-core-regen/hypertraining/hypertraining.py
@@ -1,7 +1,16 @@
import copy
from datetime import datetime
from pathlib import Path
+import random
from typing import Literal
+import matplotlib
+from matplotlib.colors import LinearSegmentedColormap
+import torch.nn.utils.parametrize
+
+try:
+ matplotlib.use("cairo")
+except ImportError:
+ matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
@@ -11,7 +20,7 @@ import optuna
import warnings
import torch
-import torch.nn as nn
+# import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
@@ -19,27 +28,9 @@ import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
-# from rich.progress import (
-# Progress,
-# TextColumn,
-# BarColumn,
-# TaskProgressColumn,
-# TimeRemainingColumn,
-# MofNCompleteColumn,
-# TimeElapsedColumn,
-# )
-# from rich.console import Console
-# from rich import print as rprint
-
import multiprocessing
from util.datasets import FiberRegenerationDataset
-
-# from util.optuna_helpers import (
-# suggest_categorical_optional, # noqa: F401
-# suggest_float_optional, # noqa: F401
-# suggest_int_optional, # noqa: F401
-# )
from util.optuna_helpers import install_optional_suggests
import util
@@ -65,7 +56,6 @@ class HyperTraining:
model_settings,
optimizer_settings,
optuna_settings,
- # console=None,
):
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
@@ -75,11 +65,8 @@ class HyperTraining:
self.optuna_settings: OptunaSettings = optuna_settings
self.processes = None
- # self.console = console or Console()
-
- # set some extra settings to make the code more readable
self._extra_optuna_settings()
- self.stop_study = True
+ self.stop_study = False
def setup_tb_writer(self, study_name=None, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
@@ -229,7 +216,7 @@ class HyperTraining:
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
def define_model(self, trial: optuna.Trial, writer=None):
- n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
+ n_hidden_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
input_dim = trial.suggest_int_optional(
"model_input_dim",
@@ -245,32 +232,41 @@ class HyperTraining:
dtype = getattr(torch, dtype)
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
- # T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
+ afunc = getattr(util.complexNN, afunc)
+ layer_func = trial.suggest_categorical_optional("model_layer_function", self.model_settings.model_layer_function)
+ layer_func = getattr(util.complexNN, layer_func)
+ layer_parametrizations = self.model_settings.model_layer_parametrizations
- layers = []
- last_dim = input_dim
- n_nodes = last_dim
- for i in range(n_layers):
+ scale_layers = trial.suggest_categorical_optional("model_enable_scale_layers", self.model_settings.scale)
+
+
+ hidden_dims = []
+ for i in range(n_hidden_layers):
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
- hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
+ hidden_dims.append(trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override))
else:
- hidden_dim = trial.suggest_int_optional(
+ hidden_dims.append(trial.suggest_int_optional(
f"model_hidden_dim_{i}",
self.model_settings.n_hidden_nodes,
- )
- layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
- last_dim = hidden_dim
- layers.append(getattr(util.complexNN, afunc)())
- n_nodes += last_dim
-
- layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
-
- model = nn.Sequential(*layers)
+ ))
+
+ model_kwargs = {
+ "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
+ "layer_function": layer_func,
+ "layer_parametrizations": layer_parametrizations,
+ "activation_function": afunc,
+ "dtype": dtype,
+ "dropout_prob": self.model_settings.dropout_prob,  # must match regenerator's keyword; a misspelled key is swallowed by **kwargs
+ "scale": scale_layers,
+ }
+
+ model = util.complexNN.regenerator(**model_kwargs)
+ n_nodes = sum(hidden_dims)
if writer is not None:
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
- n_params = sum(p.numel() for p in model.parameters())
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
trial.set_user_attr("model_n_params", n_params)
trial.set_user_attr("model_n_nodes", n_nodes)
@@ -384,7 +380,8 @@ class HyperTraining:
running_loss2 = 0.0
running_loss = 0.0
model.train()
- for batch_idx, (x, y) in enumerate(train_loader):
+ loader_len = len(train_loader)
+ for batch_idx, (x, y, _) in enumerate(train_loader):
if batch_idx >= self.optuna_settings._n_train_batches:
break
model.zero_grad(set_to_none=True)
@@ -408,14 +405,14 @@ class HyperTraining:
writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
- epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
+ epoch * min(loader_len, self.optuna_settings._n_train_batches) + batch_idx,
)
running_loss2 = 0.0
# if enable_progress:
# progress.stop()
- return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
+ return running_loss / min(loader_len, self.optuna_settings._n_train_batches)
def eval_model(
self,
@@ -446,9 +443,8 @@ class HyperTraining:
model.eval()
running_error = 0
- running_error_2 = 0
with torch.no_grad():
- for batch_idx, (x, y) in enumerate(valid_loader):
+ for batch_idx, (x, y, _) in enumerate(valid_loader):
if batch_idx >= self.optuna_settings._n_valid_batches:
break
x, y = (
@@ -459,19 +455,6 @@ class HyperTraining:
error = util.complexNN.complex_mse_loss(y_pred, y)
error_value = error.item()
running_error += error_value
- running_error_2 += error_value
-
- # if enable_progress:
- # progress.update(task, advance=1, description=f"{error_value:.3e}")
-
- if writer is not None:
- if batch_idx % self.pytorch_settings.write_every == 0:
- writer.add_scalar(
- "eval loss",
- running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
- epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
- )
- running_error_2 = 0.0
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
@@ -488,38 +471,73 @@ class HyperTraining:
),
epoch + 1,
)
+ writer.add_figure(
+ "eye diagram",
+ self.plot_model_response(
+ trial,
+ model=model,  # plot the local model this method evaluates, not an instance attribute
+ title_append=title_append,
+ subtitle=subtitle,
+ show=False,
+ mode="eye",
+ ),
+ epoch + 1,
+ )
+
+ writer.add_figure(
+ "powers",
+ self.plot_model_response(
+ trial,
+ model=model,  # plot the local model this method evaluates, not an instance attribute
+ title_append=title_append,
+ subtitle=subtitle,
+ mode="powers",
+ show=False,
+ ),
+ epoch + 1,
+ )
# if enable_progress:
# progress.stop()
return running_error
- def run_model(self, model, loader):
+ def run_model(self, model, loader, trace_powers=False):
model.eval()
- xs = []
- ys = []
- y_preds = []
+ fiber_out = []
+ fiber_in = []
+ regen = []
+ timestamps = []
+
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
- for x, y in loader:
+ for x, y, timestamp in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
- y_pred = model(x).cpu()
+ # with trace_powers the model returns (output, powers); only the output tensor has .cpu()
+ model_out = model(x, trace_powers)
+ y_pred, powers = model_out if trace_powers else (model_out, None)
+ y_pred = y_pred.cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
- xs.append(x[:, 0, :].squeeze())
- ys.append(y.squeeze())
- y_preds.append(y_pred.squeeze())
+ # timestamp = timestamp.view(-1, 1)
+ fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
+ fiber_in.append(y.squeeze())
+ regen.append(y_pred.squeeze())
+ timestamps.append(timestamp.squeeze())
- xs = torch.vstack(xs).cpu()
- ys = torch.vstack(ys).cpu()
- y_preds = torch.vstack(y_preds).cpu()
- return ys, xs, y_preds
+ fiber_out = torch.vstack(fiber_out).cpu()
+ fiber_in = torch.vstack(fiber_in).cpu()
+ regen = torch.vstack(regen).cpu()
+ timestamps = torch.concat(timestamps).cpu()
+ if trace_powers:
+ return fiber_in, fiber_out, regen, timestamps, powers
+ return fiber_in, fiber_out, regen, timestamps
def objective(self, trial: optuna.Trial, plot_before=False):
if self.stop_study:
@@ -544,7 +562,32 @@ class HyperTraining:
model=model,
title_append=title_append,
subtitle=subtitle,
- show=plot_before,
+ show=False,
+ ),
+ 0,
+ )
+ writer.add_figure(
+ "eye diagram",
+ self.plot_model_response(
+ trial,
+ model=model,  # the local model, as in the preceding figure call of this method
+ title_append=title_append,
+ subtitle=subtitle,
+ mode="eye",
+ show=False,
+ ),
+ 0,
+ )
+
+ writer.add_figure(
+ "powers",
+ self.plot_model_response(
+ trial,
+ model=model,  # the local model, as in the preceding figure call of this method
+ title_append=title_append,
+ subtitle=subtitle,
+ mode="powers",
+ show=False,
),
0,
)
@@ -609,7 +652,9 @@ class HyperTraining:
return error
- def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
+ def _plot_model_response_eye(
+ self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
+ ):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -624,27 +669,84 @@ class HyperTraining:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
- fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
+ x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
+ y_bins = np.zeros((2 * len(signals), 1000))
+ eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
+ # signals = [signal.cpu().numpy() for signal in signals]
+ for i in range(len(signals) * 2):
+ eye_signal = signals[i // 2][:, i % 2] # x, y, x, y, ...
+ eye_signal = np.real(np.square(np.abs(eye_signal)))
+ data_min = np.min(eye_signal)
+ data_max = np.max(eye_signal)
+ y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
+ # bin every sample in one pass instead of one np.digitize call per sample
+ t_vals = (timestamps / sps) % 2
+ x_idx = np.digitize(t_vals, x_bins) - 1
+ y_idx = np.digitize(eye_signal[: len(timestamps)], y_bins[i]) - 1
+ # np.add.at accumulates repeated (row, column) pairs; plain fancy-index += would count each pair once
+ np.add.at(eye_data[i], (y_idx, x_idx), 1)
+
+ cmap = LinearSegmentedColormap.from_list(
+ "eyemap",
+ [
+ (0, "white"),
+ (0.001, "dodgerblue"),
+ (0.1, "blue"),
+ (0.2, "cyan"),
+ (0.5, "lime"),
+ (0.8, "gold"),
+ (1, "red"),
+ ],
+ )
+
+ # ordering = np.argsort(timestamps)
+ # signals = [signal[ordering] for signal in signals]
+ # timestamps = timestamps[ordering]
+
+ fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
+ fig.set_figwidth(18)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
- xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
- for j, (label, signal) in enumerate(zip(labels, signals)):
+ # xaxis = timestamps / sps
+ # xaxis = np.arange(2 * sps) / sps
+ for j, label in enumerate(labels):
+ x = eye_data[2 * j]
+ y = eye_data[2 * j + 1]
+ # x, y = signal.T
# signal = signal.cpu().numpy()
- for i in range(len(signal) // sps - 1):
- x, y = signal[i * sps : (i + 2) * sps].T
- axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
- axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
- axs[0, j].set_title(label + " x")
- axs[1, j].set_title(label + " y")
- axs[0, j].set_xlabel("Symbol")
- axs[1, j].set_xlabel("Symbol")
- axs[0, j].set_ylabel("normalized power")
- axs[1, j].set_ylabel("normalized power")
+ # for i in range(len(signal) // sps - 1):
+ # x, y = signal[i * sps : (i + 2) * sps].T
+ # axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+ # axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+ axs[0 + 2 * j].imshow(
+ x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
+ )
+ axs[1 + 2 * j].imshow(
+ y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
+ )
+ axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+ axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+ ymin = np.min(y_bins[:, 0])
+ ymax = np.max(y_bins[:, -1])
+ ydiff = ymax - ymin
+ axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
+ axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
+ axs[0 + 2 * j].set_title(label + " x")
+ axs[1 + 2 * j].set_title(label + " y")
+ axs[0 + 2 * j].set_xlabel("Symbol")
+ axs[1 + 2 * j].set_xlabel("Symbol")
+ axs[0 + 2 * j].set_box_aspect(1)
+ axs[1 + 2 * j].set_box_aspect(1)
+ axs[0].set_ylabel("normalized power")
+ fig.tight_layout()
+ # axs[1+2*len(labels)-1].set_ylabel("normalized power")
if show:
plt.show()
return fig
- def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
+ def _plot_model_response_head(
+ self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
+ ):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
@@ -657,19 +759,31 @@ class HyperTraining:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
+ ordering = np.argsort(timestamps)
+ signals = [signal[ordering] for signal in signals]
+ timestamps = timestamps[ordering]
+
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
- fig.set_size_inches(18, 6)
+ fig.set_figwidth(18)
+ fig.set_figheight(4)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs):
+ ax: plt.Axes
for signal, label in zip(signals, labels):
if sps is not None:
- xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
+ xaxis = timestamps / sps
else:
- xaxis = np.arange(len(signal))
+ xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
+ ax.minorticks_on()
+ ax.tick_params(axis="y", which="minor", left=False, right=False)
+ ax.grid(which="major", axis="x")
+ ax.grid(which="minor", axis="x", linestyle=":")
+ ax.grid(which="major", axis="y")
ax.legend(loc="upper right")
+ fig.tight_layout()
if show:
plt.show()
return fig
@@ -680,22 +794,52 @@ class HyperTraining:
model=None,
title_append="",
subtitle="",
- mode: Literal["eye", "head"] = "head",
- show=True,
+ mode: Literal["eye", "head", "powers"] = "head",
+ show=False,
):
+ if mode == "powers":
+ input_data = torch.ones(
+ 1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
+ ).to(self.pytorch_settings.device)
+ model = model.to(self.pytorch_settings.device)
+ model.eval()
+ with torch.no_grad():
+ _, powers = model(input_data, trace_powers=True)
+
+ powers = [power.item() for power in powers]
+ layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
+
+ # remove dropout layers
+ mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
+ layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
+ powers = [power for power, m in zip(powers, mask) if m]
+
+ fig = self._plot_model_response_powers(
+ powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
+ )
+ return fig
+
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
- self.data_settings.drop_first = 100*128
+ self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
- plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
+ config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
+ fiber_length = int(float(str(config_path).split('-')[-7])/1000)
+ plot_loader, _ = self.get_sliced_data(
+ trial,
+ override={
+ "num_symbols": self.pytorch_settings.batchsize,
+ "config_path": config_path,
+ }
+ )
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
- fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
+ fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
@@ -703,6 +847,7 @@ class HyperTraining:
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
+ timestamps = timestamps.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
@@ -713,9 +858,10 @@ class HyperTraining:
fiber_in,
fiber_out,
regen,
+ timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
- title_append=title_append,
+ title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
@@ -725,9 +871,10 @@ class HyperTraining:
fiber_in,
fiber_out,
regen,
+ timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
- title_append=title_append,
+ title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
@@ -739,7 +886,7 @@ class HyperTraining:
@staticmethod
def build_title(trial: optuna.trial.Trial):
- title_append = f"for trial {trial.number}"
+ title_append = f"at epoch {trial.user_attrs.get("epoch", -1)} for trial {trial.number}"
model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
model_dims = [
diff --git a/src/single-core-regen/hypertraining/training.py b/src/single-core-regen/hypertraining/training.py
index fc27098..71b7bef 100644
--- a/src/single-core-regen/hypertraining/training.py
+++ b/src/single-core-regen/hypertraining/training.py
@@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
-import torch.nn as nn
+# import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
@@ -47,88 +47,6 @@ from .settings import (
)
-class regenerator(nn.Module):
- def __init__(
- self,
- *dims,
- layer_function=util.complexNN.ONN,
- layer_kwargs: dict | None = None,
- layer_parametrizations: list[dict] = None,
- # [
- # {
- # "tensor_name": "weight",
- # "parametrization": util.complexNN.Unitary,
- # },
- # {
- # "tensor_name": "scale",
- # "parametrization": util.complexNN.Clamp,
- # },
- # ],
- activation_function=util.complexNN.Pow,
- dtype=torch.float64,
- dropout_prob=0.01,
- scale=False,
- **kwargs,
- ):
- super(regenerator, self).__init__()
- if len(dims) == 0:
- try:
- dims = kwargs["dims"]
- except KeyError:
- raise ValueError("dims must be provided")
- self._n_hidden_layers = len(dims) - 2
- self._layers = nn.Sequential()
- if layer_kwargs is None:
- layer_kwargs = {}
- # self.powers = []
-
- for i in range(self._n_hidden_layers + 1):
- if scale:
- self._layers.append(util.complexNN.Scale(dims[i]))
- self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
- if i < self._n_hidden_layers:
- if dropout_prob is not None:
- self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
- self._layers.append(activation_function(bias=True, size=dims[i + 1]))
-
- self._layers.append(util.complexNN.Scale(dims[-1]))
-
- # add parametrizations
- if layer_parametrizations is not None:
- for layer in self._layers:
- for layer_parametrization in layer_parametrizations:
- tensor_name = layer_parametrization.get("tensor_name", None)
- parametrization = layer_parametrization.get("parametrization", None)
- param_kwargs = layer_parametrization.get("kwargs", {})
- if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
- parametrization(layer, tensor_name, **param_kwargs)
-
- # def __call__(self, input_x, **kwargs):
- # return self.forward(input_x, **kwargs)
-
- def forward(self, input_x, trace_powers=False):
- x = input_x
-
- if trace_powers:
- powers = [x.abs().square().sum()]
-
- # check if tracing
- if torch.jit.is_tracing():
- for layer in self._layers:
- x = layer(x)
- if trace_powers:
- powers.append(x.abs().square().sum())
- else:
- # with torch.nn.utils.parametrize.cached():
- for layer in self._layers:
- x = layer(x)
- if trace_powers:
- powers.append(x.abs().square().sum())
- if trace_powers:
- return x, powers
- return x
-
-
def traverse_dict_update(target, source):
for k, v in source.items():
if isinstance(v, dict):
@@ -164,7 +82,7 @@ class Trainer:
ModelSettings,
OptimizerSettings,
PytorchSettings,
- regenerator,
+ util.complexNN.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
if self.resume:
@@ -264,7 +182,7 @@ class Trainer:
dtype = self.model_kwargs["dtype"]
# dims = self.model_kwargs.pop("dims")
- self.model = regenerator(**self.model_kwargs)
+ self.model = util.complexNN.regenerator(**self.model_kwargs)
if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
@@ -364,7 +282,7 @@ class Trainer:
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
- # running_loss2 = 0.0
+ running_loss2 = 0.0
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
@@ -379,23 +297,24 @@ class Trainer:
loss_value = loss.item()
loss.backward()
optimizer.step()
- # running_loss2 += loss_value
+ running_loss2 += loss_value
running_loss += loss_value
if enable_progress:
- progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}")
+ progress.update(task, advance=1, description=f"{loss_value:.3e}")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
"training loss",
- running_loss / (batch_idx + 1),
- epoch * loader_len + batch_idx,
+ running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
+ epoch + batch_idx/loader_len,
)
+ running_loss2 = 0.0
if enable_progress:
progress.stop()
- return running_loss / (batch_idx + 1)
+ return running_loss / len(train_loader)
def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress:
@@ -418,7 +337,7 @@ class Trainer:
self.model.eval()
running_error = 0
with torch.no_grad():
- for batch_idx, (x, y, _) in enumerate(valid_loader):
+ for _, (x, y, _) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
@@ -429,9 +348,9 @@ class Trainer:
running_error += error_value
if enable_progress:
- progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}")
+ progress.update(task, advance=1, description=f"{error_value:.3e}")
- running_error /= (batch_idx+1)
+ running_error = running_error/len(valid_loader)
self.writer.add_scalar(
"eval loss",
@@ -858,7 +777,7 @@ class Trainer:
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
- config_path = random.choice(self.data_settings.config_path)
+ config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
plot_loader, _ = self.get_sliced_data(
override={
diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py
index 70aee0b..a68022b 100644
--- a/src/single-core-regen/regen_no_hyper.py
+++ b/src/single-core-regen/regen_no_hyper.py
@@ -1,3 +1,4 @@
+from pathlib import Path
import matplotlib
import numpy as np
import torch
@@ -22,8 +23,8 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
- # config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
- config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
+ config_path="data/20241204-13*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+ # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
@@ -52,8 +53,8 @@ model_settings = ModelSettings(
output_dim=2,
n_hidden_layers=4,
overrides={
- "n_hidden_nodes_0": 4,
- "n_hidden_nodes_1": 4,
+ "n_hidden_nodes_0": 8,
+ "n_hidden_nodes_1": 8,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 4,
},
@@ -61,7 +62,7 @@ model_settings = ModelSettings(
dropout_prob=0.01,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
- scale=True,
+ scale=False,
model_layer_parametrizations=[
{
"tensor_name": "weight",
@@ -113,7 +114,7 @@ model_settings = ModelSettings(
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer_kwargs={
- "lr": 0.05,
+ "lr": 0.01,
"amsgrad": True,
# "weight_decay": 1e-7,
},
@@ -142,8 +143,9 @@ def save_dict_to_file(dictionary, filename):
json.dump(dictionary, f, indent=4)
-def sweep_lengths(*lengths, model=None):
+def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
assert model is not None, "Model must be provided."
+ assert data_glob is not None, "Data glob must be provided."
model = model
fiber_ins = {}
@@ -151,19 +153,31 @@ def sweep_lengths(*lengths, model=None):
regens = {}
timestampss = {}
- for length in lengths:
- trainer = Trainer(
+ trainer = Trainer(
checkpoint_path=model,
- settings_override={
- "data_settings": {
- "config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
- "train_split": 1,
- "shuffle": True,
- }
- },
)
- trainer.define_model()
- loader, _ = trainer.get_sliced_data()
+ trainer.define_model()
+
+ for length in lengths:
+ data_glob_length = data_glob.replace("{length}", str(length))
+ files = list(Path.cwd().glob(data_glob_length))
+ if len(files) == 0:
+ continue
+ if strategy == "newest":
+ sorted_kwargs = {
+ 'key': lambda x: x.stat().st_mtime,
+ 'reverse': True,
+ }
+ elif strategy == "oldest":
+ sorted_kwargs = {
+ 'key': lambda x: x.stat().st_mtime,
+ 'reverse': False,
+ }
+ else:
+ raise ValueError(f"Unknown strategy {strategy}.")
+ file = sorted(files, **sorted_kwargs)[0]
+
+ loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
@@ -171,17 +185,23 @@ def sweep_lengths(*lengths, model=None):
regens[length] = regen
timestampss[length] = timestamps
- data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
- channel_names = ["" for _ in range(2 * len(lengths))]
+ data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
+ channel_names = ["" for _ in range(2 * len(timestampss.keys())+2)]
- for li, length in enumerate(lengths):
- data[2 * li, 0, :] = timestampss[length] / 128
- data[2 * li, 1, :] = regens[length][:, 0].abs().square()
- data[2 * li + 1, 0, :] = timestampss[length] / 128
- data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
+ data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
+ data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
- channel_names[2 * li] = f"regen x {length}"
- channel_names[2 * li + 1] = f"regen y {length}"
+ channel_names[1] = "fiber in x"
+
+
+ for li, length in enumerate(timestampss.keys()):
+ data[2+2 * li, 0, :] = timestampss[length] / 128
+ data[2+2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
+ data[2+2 * li + 1, 0, :] = timestampss[length] / 128
+ data[2+2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
+
+ channel_names[2+2 * li+1] = f"regen x {length}"
+ channel_names[2+2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
@@ -189,28 +209,30 @@ def sweep_lengths(*lengths, model=None):
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
- print_attrs = ("channel", "success", "min_area")
+ print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
- eye.plot()
+ eye.plot(all_stats=False)
matplotlib.use(backend)
if __name__ == "__main__":
- # sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
+ lengths = range(90000, 100000+10000, 10000)
+ # lengths = [100000]
+ sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
- trainer = Trainer(
- global_settings=global_settings,
- data_settings=data_settings,
- pytorch_settings=pytorch_settings,
- model_settings=model_settings,
- optimizer_settings=optimizer_settings,
- # checkpoint_path=".models/best_20241202_143149.tar",
- # 20241202_143149
- )
- trainer.train()
\ No newline at end of file
+ # trainer = Trainer(
+ # global_settings=global_settings,
+ # data_settings=data_settings,
+ # pytorch_settings=pytorch_settings,
+ # model_settings=model_settings,
+ # optimizer_settings=optimizer_settings,
+ # # checkpoint_path=".models/best_20241202_143149.tar",
+ # # 20241202_143149
+ # )
+ # trainer.train()
\ No newline at end of file
diff --git a/src/single-core-regen/sliced_dataset_test.py b/src/single-core-regen/sliced_dataset_test.py
index 284885f..d0160b7 100644
--- a/src/single-core-regen/sliced_dataset_test.py
+++ b/src/single-core-regen/sliced_dataset_test.py
@@ -39,7 +39,7 @@ import numpy as np
if __name__ == "__main__":
- dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
+ dataset = FiberRegenerationDataset("data/202412*-128-16384-50000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
loader = DataLoader(dataset, batch_size=10, shuffle=True)
diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py
index c47fcd4..a8cd75e 100644
--- a/src/single-core-regen/util/complexNN.py
+++ b/src/single-core-regen/util/complexNN.py
@@ -569,6 +569,78 @@ class ZReLU(nn.Module):
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
else:
return torch.relu(x)
+
+
+class regenerator(nn.Module):
+ def __init__(
+ self,
+ *dims,
+ layer_function=ONN,
+ layer_kwargs: dict | None = None,
+ layer_parametrizations: list[dict] = None,
+ activation_function=Pow,
+ dtype=torch.float64,
+ dropout_prob=0.01,
+ scale=False,
+ **kwargs,
+ ):
+ super(regenerator, self).__init__()
+ if len(dims) == 0:
+ try:
+ dims = kwargs["dims"]
+ except KeyError:
+ raise ValueError("dims must be provided")
+ self._n_hidden_layers = len(dims) - 2
+ self._layers = nn.Sequential()
+ if layer_kwargs is None:
+ layer_kwargs = {}
+ # self.powers = []
+
+ for i in range(self._n_hidden_layers + 1):
+ if scale:
+ self._layers.append(Scale(dims[i]))
+ self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
+ if i < self._n_hidden_layers:
+ if dropout_prob is not None:
+ self._layers.append(DropoutComplex(p=dropout_prob))
+ self._layers.append(activation_function(bias=True, size=dims[i + 1]))
+
+ self._layers.append(Scale(dims[-1]))
+
+ # add parametrizations
+ if layer_parametrizations is not None:
+ for layer in self._layers:
+ for layer_parametrization in layer_parametrizations:
+ tensor_name = layer_parametrization.get("tensor_name", None)
+ parametrization = layer_parametrization.get("parametrization", None)
+ param_kwargs = layer_parametrization.get("kwargs", {})
+ if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
+ parametrization(layer, tensor_name, **param_kwargs)
+
+ # def __call__(self, input_x, **kwargs):
+ # return self.forward(input_x, **kwargs)
+
+ def forward(self, input_x, trace_powers=False):
+ x = input_x
+
+ if trace_powers:
+ powers = [x.abs().square().sum()]
+
+ # check if tracing
+ if torch.jit.is_tracing():
+ for layer in self._layers:
+ x = layer(x)
+ if trace_powers:
+ powers.append(x.abs().square().sum())
+ else:
+ # with torch.nn.utils.parametrize.cached():
+ for layer in self._layers:
+ x = layer(x)
+ if trace_powers:
+ powers.append(x.abs().square().sum())
+ if trace_powers:
+ return x, powers
+ return x
__all__ = [
diff --git a/src/single-core-regen/util/eye_diagram.py b/src/single-core-regen/util/eye_diagram.py
index b07047f..61115e1 100644
--- a/src/single-core-regen/util/eye_diagram.py
+++ b/src/single-core-regen/util/eye_diagram.py
@@ -3,6 +3,7 @@ from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from scipy.cluster.vq import kmeans2
import warnings
+import multiprocessing
from rich.traceback import install
from rich import pretty
@@ -67,7 +68,7 @@ def generate_wavelet(sps, oversample=3):
class eye_diagram:
- def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4):
+ def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
# data has shape [channels, 2, samples]
# each sample has a timestamp and a value
if data.ndim == 2:
@@ -79,28 +80,38 @@ class eye_diagram:
self.eye_stats = [{"success": False} for _ in range(self.channels)]
self.horizontal_bins = horizontal_bins
self.vertical_bins = vertical_bins
+ self.multi_threaded = multithreaded
self.eye_built = False
- self.analyse(self.n_levels)
+ self.analyse()
def generate_eye_data(self):
+ # Fill self.eye_data and self.y_bins with one 2-D histogram and one set of
+ # vertical bin edges per channel, then mark the eye as built.
+ # self.x_bins (horizontal bin edges over [0, 2)) must be assigned before the
+ # per-channel helper runs, because generate_eye_data_single reads it.
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.channels, self.vertical_bins))
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
- for i in range(self.channels):
- data_min = np.min(self.raw_data[i, 1, :])
- data_max = np.max(self.raw_data[i, 1, :])
- self.y_bins[i] = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
-
- t_vals = self.raw_data[i, 0, :] % 2
- val_vals = self.raw_data[i, 1, :]
-
- x_indices = np.digitize(t_vals, self.x_bins) - 1
- y_indices = np.digitize(val_vals, self.y_bins[i]) - 1
-
- np.add.at(self.eye_data[i], (y_indices, x_indices), 1)
+ datas = [self.raw_data[i] for i in range(self.channels)]
+ # Despite the attribute name, this branch fans out to worker *processes*
+ # (a default-sized multiprocessing.Pool), one task per channel.
+ if self.multi_threaded:
+ with multiprocessing.Pool() as pool:
+ results = pool.map(self.generate_eye_data_single, datas)
+ for i, result in enumerate(results):
+ self.eye_data[i], self.y_bins[i] = result
+ else:
+ # Sequential fallback: same helper, evaluated channel by channel in-process.
+ for i, data in enumerate(datas):
+ self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True
+
+ def generate_eye_data_single(self, data):
+     """Build the 2-D eye histogram of a single channel.
+
+     Args:
+         data: Array indexed as ``data[0, :]`` (timestamps) and
+             ``data[1, :]`` (values) for one channel.
+
+     Returns:
+         ``(eye_data, y_bins)``: the ``vertical_bins x horizontal_bins``
+         count matrix and the vertical bin edges spanning this channel's
+         own value range.
+
+     Reads ``self.x_bins``, which ``generate_eye_data`` assigns before
+     calling this helper.
+     """
+     timestamps, values = data[0, :], data[1, :]
+
+     # Vertical edges run from the channel minimum up to (excluding) its maximum.
+     y_bins = np.linspace(np.min(values), np.max(values), self.vertical_bins, endpoint=False)
+
+     # Fold time into [0, 2); digitize reports 1-based bin numbers, hence the -1.
+     cols = np.digitize(timestamps % 2, self.x_bins) - 1
+     rows = np.digitize(values, y_bins) - 1
+
+     # add.at accumulates correctly when several samples hit the same cell.
+     histogram = np.zeros((self.vertical_bins, self.horizontal_bins))
+     np.add.at(histogram, (rows, cols), 1)
+     return histogram, y_bins
- def plot(self, title="Eye Diagram", stats=True, show=True):
+ def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
if not self.eye_built:
self.generate_eye_data()
cmap = LinearSegmentedColormap.from_list(
@@ -118,8 +129,10 @@ class eye_diagram:
ax = np.atleast_1d(ax).transpose().flatten()
for i in range(self.channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
- ax[i].set_xlabel("Symbol")
- ax[i].set_ylabel("Amplitude")
+ if (i+1) % rows == 0:
+ ax[i].set_xlabel("Symbol")
+ if i < rows:
+ ax[i].set_ylabel("Amplitude")
ax[i].grid()
ax[i].imshow(
self.eye_data[i],
@@ -134,67 +147,6 @@ class eye_diagram:
yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
if stats and self.eye_stats[i]["success"]:
- ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
- ax[i].set_yticks(self.eye_stats[i]["levels"])
- # add arrows for amplitudes
- for j in range(len(self.eye_stats[i]["amplitudes"])):
- ax[i].annotate(
- "",
- xy=(0.05, self.eye_stats[i]["levels"][j]),
- xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
- arrowprops=dict(arrowstyle="<->", facecolor="black"),
- )
- ax[i].annotate(
- f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
- xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
- )
- # add arrows for eye heights
- for j in range(len(self.eye_stats[i]["heights"])):
- try:
- bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
- top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
-
- ax[i].annotate(
- "",
- xy=(self.eye_stats[i]["time_midpoint"], bot),
- xytext=(self.eye_stats[i]["time_midpoint"], top),
- arrowprops=dict(arrowstyle="<->", facecolor="black"),
- )
- ax[i].annotate(
- f"{self.eye_stats[i]['heights'][j]:.2e}",
- xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
- )
- except (ValueError, IndexError):
- pass
- # add arrows for eye widths
- for j in range(len(self.eye_stats[i]["widths"])):
- try:
- left = np.max(self.eye_stats[i]["time_clusters"][j][0])
- right = np.min(self.eye_stats[i]["time_clusters"][j][1])
- vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
-
- ax[i].annotate(
- "",
- xy=(left, vertical),
- xytext=(right, vertical),
- arrowprops=dict(arrowstyle="<->", facecolor="black"),
- )
- ax[i].annotate(
- f"{self.eye_stats[i]['widths'][j]:.2e}",
- xy=((left + right) / 2 - 0.15, vertical + 0.01),
- )
- except (ValueError, IndexError):
- pass
-
- # add area
- for j in range(len(self.eye_stats[i]["areas"])):
- horizontal = self.eye_stats[i]["time_midpoint"]
- vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
- ax[i].annotate(
- f"{self.eye_stats[i]['areas'][j]:.2e}",
- xy=(horizontal + 0.035, vertical - 0.07),
- )
-
# add min_area above the plot
ax[i].annotate(
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
@@ -202,62 +154,142 @@ class eye_diagram:
# xycoords="axes fraction",
ha="left",
va="center",
+ bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
)
+
+ if all_stats:
+ ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
+ ax[i].set_yticks(self.eye_stats[i]["levels"])
+ # add arrows for amplitudes
+ for j in range(len(self.eye_stats[i]["amplitudes"])):
+ ax[i].annotate(
+ "",
+ xy=(0.05, self.eye_stats[i]["levels"][j]),
+ xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
+ arrowprops=dict(arrowstyle="<->", facecolor="black"),
+ )
+ ax[i].annotate(
+ f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
+ xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
+ bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+ )
+ # add arrows for eye heights
+ for j in range(len(self.eye_stats[i]["heights"])):
+ try:
+ bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
+ top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
+
+ ax[i].annotate(
+ "",
+ xy=(self.eye_stats[i]["time_midpoint"], bot),
+ xytext=(self.eye_stats[i]["time_midpoint"], top),
+ arrowprops=dict(arrowstyle="<->", facecolor="black"),
+ )
+ ax[i].annotate(
+ f"{self.eye_stats[i]['heights'][j]:.2e}",
+ xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
+ bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+ )
+ except (ValueError, IndexError):
+ pass
+ # add arrows for eye widths
+ for j in range(len(self.eye_stats[i]["widths"])):
+ try:
+ left = np.max(self.eye_stats[i]["time_clusters"][j][0])
+ right = np.min(self.eye_stats[i]["time_clusters"][j][1])
+ vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+
+ ax[i].annotate(
+ "",
+ xy=(left, vertical),
+ xytext=(right, vertical),
+ arrowprops=dict(arrowstyle="<->", facecolor="black"),
+ )
+ ax[i].annotate(
+ f"{self.eye_stats[i]['widths'][j]:.2e}",
+ xy=((left + right) / 2 - 0.15, vertical + 0.01),
+ bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+ )
+ except (ValueError, IndexError):
+ pass
+
+ # add area
+ for j in range(len(self.eye_stats[i]["areas"])):
+ horizontal = self.eye_stats[i]["time_midpoint"]
+ vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+ ax[i].annotate(
+ f"{self.eye_stats[i]['areas'][j]:.2e}",
+ xy=(horizontal + 0.035, vertical - 0.07),
+ bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+ )
+
fig.tight_layout()
if show:
plt.show()
return fig
-
- def analyse(self, n_levels=4):
+
+ def analyse_single(self, data, index):
+     """Compute the eye statistics of one channel and return them as a dict.
+
+     Args:
+         data: Raw data of a single channel; it is handed unchanged to the
+             ``eye_diagram`` helper functions.
+         index: Zero-based channel index, used to look up the channel name
+             (falls back to ``str(index + 1)`` when no names were given).
+
+     Returns:
+         A dict with the keys ``channel_name``, ``time_midpoint``,
+         ``levels``, ``amplitude_clusters``, ``amplitudes``, ``heights``,
+         ``widths``, ``areas``, ``mean_area``, ``min_area`` and ``success``
+         (plus ``time_clusters`` when the analysis succeeded). If a helper
+         raises ``ValueError`` or emits a ``RuntimeWarning``/``UserWarning``,
+         ``success`` is ``False`` and the other entries are placeholders
+         (zeros sized by ``self.n_levels`` / an empty list).
+     """
- warnings.filterwarnings("error")
- for i in range(self.channels):
- self.eye_stats[i]["channel"] = str(i+1) if self.channel_names is None else self.channel_names[i]
- try:
- approx_levels = eye_diagram.approximate_levels(self.raw_data[i], n_levels)
+     eye_stats = {}
+     eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index]
+     # Escalate warnings to exceptions only inside this block: catch_warnings()
+     # restores the caller's filters on exit, even if an unexpected exception
+     # escapes, whereas resetwarnings() would discard every configured filter.
+     with warnings.catch_warnings():
+         warnings.filterwarnings("error")
+         try:
+             approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
- time_bounds = eye_diagram.calculate_time_bounds(self.raw_data[i], approx_levels)
+             time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
- self.eye_stats[i]["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
+             eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
- self.eye_stats[i]["levels"], self.eye_stats[i]["amplitude_clusters"] = eye_diagram.calculate_levels(
- self.raw_data[i], approx_levels, time_bounds
- )
+             eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
+                 data, approx_levels, time_bounds
+             )
- self.eye_stats[i]["amplitudes"] = np.diff(self.eye_stats[i]["levels"])
+             eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
- self.eye_stats[i]["heights"] = eye_diagram.calculate_eye_heights(
- self.eye_stats[i]["amplitude_clusters"]
- )
+             eye_stats["heights"] = eye_diagram.calculate_eye_heights(
+                 eye_stats["amplitude_clusters"]
+             )
- self.eye_stats[i]["widths"], self.eye_stats[i]["time_clusters"] = eye_diagram.calculate_eye_widths(
- self.raw_data[i], self.eye_stats[i]["levels"]
- )
+             eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
+                 data, eye_stats["levels"]
+             )
- # # check if time clusters are valid (upper bound > time_midpoint > lower bound)
- # # if not: raise ValueError
- # for j in range(len(self.eye_stats[i]['time_clusters'])):
- # if not (np.max(self.eye_stats[i]['time_clusters'][j][0]) < self.eye_stats[i]["time_midpoint"] < np.min(self.eye_stats[i]['time_clusters'][j][1])):
- # raise ValueError
+             # # check if time clusters are valid (upper bound > time_midpoint > lower bound)
+             # # if not: raise ValueError
+             # for j in range(len(eye_stats['time_clusters'])):
+             # if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
+             # raise ValueError
- self.eye_stats[i]["areas"] = self.eye_stats[i]["heights"] * self.eye_stats[i]["widths"]
- self.eye_stats[i]["mean_area"] = np.mean(self.eye_stats[i]["areas"])
- self.eye_stats[i]["min_area"] = np.min(self.eye_stats[i]["areas"])
-
- self.eye_stats[i]["success"] = True
- except (RuntimeWarning, UserWarning, ValueError):
- self.eye_stats[i]["success"] = False
- self.eye_stats[i]["time_midpoint"] = 0
- self.eye_stats[i]["levels"] = np.zeros(n_levels)
- self.eye_stats[i]["amplitude_clusters"] = []
- self.eye_stats[i]["amplitudes"] = np.zeros(n_levels - 1)
- self.eye_stats[i]["heights"] = np.zeros(n_levels - 1)
- self.eye_stats[i]["widths"] = np.zeros(n_levels - 1)
- self.eye_stats[i]["areas"] = np.zeros(n_levels - 1)
- self.eye_stats[i]["mean_area"] = 0
- self.eye_stats[i]["min_area"] = 0
+             eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
+             eye_stats["mean_area"] = np.mean(eye_stats["areas"])
+             eye_stats["min_area"] = np.min(eye_stats["areas"])
+             eye_stats["success"] = True
+         except (RuntimeWarning, UserWarning, ValueError):
+             # Any escalated warning or invalid clustering marks the channel as
+             # failed and substitutes placeholders so consumers can still index.
+             eye_stats["success"] = False
+             eye_stats["time_midpoint"] = 0
+             eye_stats["levels"] = np.zeros(self.n_levels)
+             eye_stats["amplitude_clusters"] = []
+             eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
+             eye_stats["heights"] = np.zeros(self.n_levels - 1)
+             eye_stats["widths"] = np.zeros(self.n_levels - 1)
+             eye_stats["areas"] = np.zeros(self.n_levels - 1)
+             eye_stats["mean_area"] = 0
+             eye_stats["min_area"] = 0
- warnings.resetwarnings()
+     return eye_stats
+
+
+ def analyse(self):
+     """Recompute ``self.eye_stats`` with one result dict per channel.
+
+     Each channel is evaluated by ``analyse_single(self.raw_data[i], i)``;
+     results are stored in channel order. When ``self.multi_threaded`` is
+     truthy the calls are distributed over a ``multiprocessing.Pool`` of
+     worker processes, otherwise they run sequentially in this process.
+     """
+     self.eye_stats = []
+     if self.multi_threaded:
+         jobs = [(self.raw_data[i], i) for i in range(self.channels)]
+         with multiprocessing.Pool() as pool:
+             # starmap returns results in submission order, i.e. by channel.
+             self.eye_stats.extend(pool.starmap(self.analyse_single, jobs))
+     else:
+         self.eye_stats.extend(
+             self.analyse_single(self.raw_data[i], i) for i in range(self.channels)
+         )
@staticmethod
def approximate_levels(data, levels):
diff --git a/src/visualization/viz.ipynb b/src/visualization/viz.ipynb
new file mode 100644
index 0000000..4737357
--- /dev/null
+++ b/src/visualization/viz.ipynb
@@ -0,0 +1,82160 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'dot'"
+ ]
+ },
+ "execution_count": 92,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import graphviz\n",
+ "graphviz.set_default_engine('neato')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 332,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 332,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from typing import Literal\n",
+ "\n",
+ "\n",
+ "def clements(N, prefix=None, offset=(0, 0)):\n",
+ " g = graphviz.Graph()\n",
+ "\n",
+ " if prefix is None:\n",
+ " prefix = \"\"\n",
+ " elif not isinstance(prefix, str):\n",
+ " prefix = f\"{prefix}_\"\n",
+ " else:\n",
+ " prefix = f\"{prefix}_\" if len(prefix) else \"\"\n",
+ "\n",
+ " order = list(range(N))\n",
+ " connections = [[] for _ in range(N)]\n",
+ "\n",
+ " common_node_kwargs = dict(\n",
+ " label=\"\",\n",
+ " shape=\"diamond\",\n",
+ " fixedsize=\"true\",\n",
+ " width=\"0.6\",\n",
+ " height=\"0.6\",\n",
+ " fillcolor=\"gray85\",\n",
+ " pin=\"true\",\n",
+ " z=\"0\",\n",
+ " style=\"filled\",\n",
+ " )\n",
+ " none_kwargs = dict(shape=\"none\", fillcolor=\"none\")\n",
+ "\n",
+ " ranges = [None, None, None, None]\n",
+ "\n",
+ " def update_ranges(x, y):\n",
+ " # ranges: min_x, max_x, min_y, max_y\n",
+ " if ranges[0] is None or x < ranges[0]:\n",
+ " ranges[0] = x\n",
+ " if ranges[1] is None or x > ranges[1]:\n",
+ " ranges[1] = x\n",
+ " if ranges[2] is None or y < ranges[2]:\n",
+ " ranges[2] = y\n",
+ " if ranges[3] is None or y > ranges[3]:\n",
+ " ranges[3] = y\n",
+ "\n",
+ " def add_node(x, y, name, *connections_indices, **node_kwargs):\n",
+ " update_ranges(x, y)\n",
+ " name = f\"{prefix}{name}\"\n",
+ " kwargs = common_node_kwargs.copy()\n",
+ " kwargs.update(node_kwargs)\n",
+ " g.node(name, pos=f\"{x},{y}\", **kwargs)\n",
+ " for i in connections_indices:\n",
+ " connections[i].append((name, \"false\" if kwargs.get(\"shape\", \"none\") == \"none\" else \"true\"))\n",
+ "\n",
+ " current_x = 0\n",
+ "\n",
+ " for i in range(N):\n",
+ " add_node(offset[0], -i + offset[1], f\"in_{i}\", i, **none_kwargs)\n",
+ " current_x += 0.5\n",
+ "\n",
+ " for i in range(N + 1):\n",
+ " for j in range(N):\n",
+ " add_node(current_x + offset[0], -j + offset[1], f\"{i}_{j}\", order[j], **none_kwargs)\n",
+ " if 0 < i < N:\n",
+ " add_node(current_x + offset[0] + 0.5, -j + offset[1], f\"{i}__{j}\", order[j], **none_kwargs)\n",
+ " current_x += 0.5\n",
+ " if 0 < i < N:\n",
+ " current_x += 0.5\n",
+ " if i > N - 1:\n",
+ " break\n",
+ " for j in range(i % 2, N + i % 2, 2):\n",
+ " try:\n",
+ " signal_in1 = order[j]\n",
+ " signal_in2 = order[j + 1]\n",
+ " name = f\"x_{signal_in1}_{signal_in2}\"\n",
+ " add_node(current_x + offset[0], -j - 0.5 + offset[1], name, signal_in1, signal_in2)\n",
+ " order[j], order[j + 1] = order[j + 1], order[j]\n",
+ " except IndexError:\n",
+ " pass\n",
+ " \n",
+ " if N <= 2:\n",
+ " current_x += 0.5\n",
+ " for j in range(N):\n",
+ " add_node(current_x + offset[0], -j + offset[1], f\"{i+1}_{j}\", order[j], **none_kwargs)\n",
+ " current_x += 0.5\n",
+ " break\n",
+ "\n",
+ " if N % 2 == 0:\n",
+ " if i % 2:\n",
+ " add_node(current_x + offset[0], 0.5 + offset[1], f\"up_{i}\", order[0], **none_kwargs)\n",
+ " add_node(current_x + offset[0], -N + 0.5 + offset[1], f\"down_{i}\", order[-1], **none_kwargs)\n",
+ " else:\n",
+ " if i % 2 == 0:\n",
+ " add_node(current_x + offset[0], -N + 0.5 + offset[1], f\"down_{i}\", order[-1], **none_kwargs)\n",
+ " else:\n",
+ " add_node(current_x + offset[0], 0.5 + offset[1], f\"up_{i}\", order[0], **none_kwargs)\n",
+ "\n",
+ " current_x += 0.5\n",
+ "\n",
+ " for i in range(N):\n",
+ " # add_node(N + 0.5 + offset[0], -i + offset[1], f\"out_{i}\", order[i], **none_kwargs)\n",
+ " add_node(current_x + offset[0], -i + offset[1], f\"out_{i}\", order[i], **none_kwargs)\n",
+ "\n",
+ " for i in range(N):\n",
+ " for a, b in zip(connections[i][:-1], connections[i][1:]):\n",
+ " g.edge(a[0], b[0], headclip=b[1], tailclip=a[1], arrowhead=\"none\")\n",
+ "\n",
+ " g.range = ranges\n",
+ " g.top_coord = (current_x + offset[0], offset[1])\n",
+ " g.n_nodes = int(N * (N - 1) // 2)\n",
+ "\n",
+ " return g\n",
+ "\n",
+ "\n",
+ "def activation_layer(N, N_out=None, prefix=None, offset=(0, 0), strat: Literal[\"top\", \"center\", \"bottom\"] = \"center\"):\n",
+ " g = graphviz.Graph()\n",
+ "\n",
+ " if N_out is None:\n",
+ " N_out = N\n",
+ "\n",
+ " assert N_out <= N\n",
+ "\n",
+ " if prefix is None:\n",
+ " prefix = \"\"\n",
+ " elif not isinstance(prefix, str):\n",
+ " prefix = f\"{prefix}_\"\n",
+ " else:\n",
+ " prefix = f\"{prefix}_\" if len(prefix) else \"\"\n",
+ "\n",
+ " act_width = 0.8\n",
+ " term_width = 0.5\n",
+ "\n",
+ " term_offset = (act_width - term_width) / 2\n",
+ "\n",
+ " act_node_kwargs = dict(\n",
+ " label=\"\",\n",
+ " shape=\"cds\",\n",
+ " fixedsize=\"true\",\n",
+ " width=\"0.8\",\n",
+ " height=\"0.6\",\n",
+ " fillcolor=\"gray85\",\n",
+ " pin=\"true\",\n",
+ " z=\"0\",\n",
+ " style=\"filled\",\n",
+ " )\n",
+ "\n",
+ " term_node_kwargs = dict(\n",
+ " label=\"\",\n",
+ " shape=\"square\",\n",
+ " fixedsize=\"true\",\n",
+ " width=\"0.5\",\n",
+ " fillcolor=\"gray85\",\n",
+ " pin=\"true\",\n",
+ " z=\"0\",\n",
+ " style=\"filled\",\n",
+ " )\n",
+ "\n",
+ " none_node_kwargs = dict(label=\"\", shape=\"none\", fillcolor=\"none\", pin=\"true\", z=\"0\")\n",
+ "\n",
+ " for i in range(N):\n",
+ " g.node(f\"{prefix}in_{i}\", pos=f\"{offset[0]},{-i + offset[1]}\", **none_node_kwargs)\n",
+ "\n",
+ "\n",
+ "\n",
+ " if strat == \"top\":\n",
+ " for out, num in enumerate(range(N_out)):\n",
+ " if out == 0:\n",
+ " y_out = -num + offset[1]\n",
+ " g.node(f\"{prefix}act_{out}\", pos=f\"{offset[0]+0.5},{-num + offset[1]}\", **act_node_kwargs)\n",
+ " g.node(f\"{prefix}out_{out}\", pos=f\"{offset[0]+1},{-num + offset[1]}\", **none_node_kwargs)\n",
+ " g.edge(f\"{prefix}in_{num}\", f\"{prefix}act_{out}\", tailclip=\"false\", headclip=\"true\", arrowhead=\"none\")\n",
+ " g.edge(f\"{prefix}act_{out}\", f\"{prefix}out_{out}\", tailclip=\"true\", headclip=\"false\", arrowhead=\"none\")\n",
+ " for num in range(N_out, N):\n",
+ " g.node(f\"{prefix}term_{num}\", pos=f\"{offset[0]+0.5 - term_offset},{-num + offset[1]}\", **term_node_kwargs)\n",
+ " g.edge(f\"{prefix}in_{num}\", f\"{prefix}term_{num}\", tailclip=\"false\", headclip=\"true\", arrowhead=\"none\")\n",
+ "\n",
+ " elif strat == \"center\":\n",
+ " for out, num in enumerate(range((N - N_out) // 2, N_out + (N - N_out) // 2)):\n",
+ " if out == 0:\n",
+ " y_out = -num + offset[1]\n",
+ " g.node(f\"{prefix}act_{out}\", pos=f\"{offset[0]+0.5},{-num + offset[1]}\", **act_node_kwargs)\n",
+ " g.node(f\"{prefix}out_{out}\", pos=f\"{offset[0]+1},{-num + offset[1]}\", **none_node_kwargs)\n",
+ " g.edge(f\"{prefix}in_{num}\", f\"{prefix}act_{out}\", tailclip=\"false\", headclip=\"true\", arrowhead=\"none\")\n",
+ " g.edge(f\"{prefix}act_{out}\", f\"{prefix}out_{out}\", tailclip=\"true\", headclip=\"false\", arrowhead=\"none\")\n",
+ " for num in range((N - N_out) // 2):\n",
+ " g.node(f\"{prefix}term_{num}\", pos=f\"{offset[0]+0.5 - term_offset},{-num + offset[1]}\", **term_node_kwargs)\n",
+ " g.edge(f\"{prefix}in_{num}\", f\"{prefix}term_{num}\", tailclip=\"false\", headclip=\"true\", arrowhead=\"none\")\n",
+ " for num in range(N_out + (N - N_out) // 2, N):\n",
+ " g.node(f\"{prefix}term_{num}\", pos=f\"{offset[0]+0.5 - term_offset},{-num + offset[1]}\", **term_node_kwargs)\n",
+ " g.edge(f\"{prefix}in_{num}\", f\"{prefix}term_{num}\", tailclip=\"false\", headclip=\"true\", arrowhead=\"none\")\n",
+ "\n",
+ " elif strat == \"bottom\":\n",
+ " for out, num in enumerate(range(N - N_out, N)):\n",
+ " if out == 0:\n",
+ " y_out = -num + offset[1]\n",
+ " g.node(f\"{prefix}act_{out}\", pos=f\"{offset[0]+0.5},{-num + offset[1]}\", **act_node_kwargs)\n",
+ " g.node(f\"{prefix}out_{out}\", pos=f\"{offset[0]+1},{-num + offset[1]}\", **none_node_kwargs)\n",
+ " g.edge(f\"{prefix}in_{num}\", f\"{prefix}act_{out}\", tailclip=\"false\", headclip=\"true\", arrowhead=\"none\")\n",
+ " g.edge(f\"{prefix}act_{out}\", f\"{prefix}out_{out}\", tailclip=\"true\", headclip=\"false\", arrowhead=\"none\")\n",
+ " for num in range(N - N_out):\n",
+ " g.node(f\"{prefix}term_{num}\", pos=f\"{offset[0]+0.5 - term_offset},{-num + offset[1]}\", **term_node_kwargs)\n",
+ " g.edge(f\"{prefix}in_{num}\", f\"{prefix}term_{num}\", tailclip=\"false\", headclip=\"true\", arrowhead=\"none\")\n",
+ "\n",
+ "\n",
+ " else:\n",
+ " raise ValueError(f\"Invalid strat: {strat}\")\n",
+ "\n",
+ " g.range = [offset[0], offset[0]+1, - N + offset[1], offset[1]]\n",
+ " g.top_coord = (offset[0] + 1, y_out)\n",
+ "\n",
+ " return g\n",
+ "\n",
+ "\n",
+ "g = clements(3)\n",
+ "# print(g.range)\n",
+ "g\n",
+ "\n",
+ "# a = activation_layer(4, 2, strat=\"bottom\")\n",
+ "# a\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 334,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1395\n"
+ ]
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 334,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "p = graphviz.Graph()\n",
+ "\n",
+ "\n",
+ "def add_subgraph(p, c):\n",
+ " p.subgraphs = getattr(p, \"subgraphs\", [])\n",
+ " p.subgraphs.append(c)\n",
+ " p.subgraph(c)\n",
+ "\n",
+ "\n",
+ "input_size = 26*2\n",
+ "hidden_sizes = [8,8,4,4]\n",
+ "output_size = 2\n",
+ "\n",
+ "sizes = (input_size, *hidden_sizes, output_size)\n",
+ "\n",
+ "for i, h in enumerate(sizes):\n",
+ " if i == 0:\n",
+ " add_subgraph(p, clements(h, prefix=i))\n",
+ " else:\n",
+ " add_subgraph(\n",
+ " p, activation_layer(sizes[i-1], h, prefix=f\"act_{i}\", offset=(p.subgraphs[-1].top_coord[0], p.subgraphs[-1].top_coord[1]), strat=\"center\")\n",
+ " )\n",
+ " add_subgraph(p, clements(h, prefix=i, offset=(p.subgraphs[-1].top_coord[0], p.subgraphs[-1].top_coord[1])))\n",
+ "\n",
+ "n_nodes = sum([g.n_nodes for g in p.subgraphs if hasattr(g, \"n_nodes\")])\n",
+ "\n",
+ "print(n_nodes)\n",
+ "\n",
+ "p\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 335,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'out.svg'"
+ ]
+ },
+ "execution_count": 335,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "p.format='svg'\n",
+ "p.render('out')\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}