add regenerator class and update dataset configurations for model training

Joseph Hopfmüller
2024-12-05 23:55:03 +01:00
parent 884d9f73c9
commit 0e29b87395
7 changed files with 82705 additions and 353 deletions

View File

@@ -1,7 +1,16 @@
 import copy
 from datetime import datetime
 from pathlib import Path
+import random
 from typing import Literal
+
+import matplotlib
+from matplotlib.colors import LinearSegmentedColormap
+import torch.nn.utils.parametrize
+
+try:
+    matplotlib.use("cairo")
+except ImportError:
+    matplotlib.use("Agg")
 import matplotlib.pyplot as plt
 import numpy as np
@@ -11,7 +20,7 @@ import optuna
 import warnings
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 # import torch.nn.functional as F  # mse_loss doesn't support complex numbers
 import torch.optim as optim
@@ -19,27 +28,9 @@ import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
-# from rich.progress import (
-#     Progress,
-#     TextColumn,
-#     BarColumn,
-#     TaskProgressColumn,
-#     TimeRemainingColumn,
-#     MofNCompleteColumn,
-#     TimeElapsedColumn,
-# )
-# from rich.console import Console
-# from rich import print as rprint
 import multiprocessing

 from util.datasets import FiberRegenerationDataset
-# from util.optuna_helpers import (
-#     suggest_categorical_optional,  # noqa: F401
-#     suggest_float_optional,  # noqa: F401
-#     suggest_int_optional,  # noqa: F401
-# )
 from util.optuna_helpers import install_optional_suggests
 import util
@@ -65,7 +56,6 @@ class HyperTraining:
         model_settings,
         optimizer_settings,
         optuna_settings,
-        # console=None,
     ):
         self.global_settings: GlobalSettings = global_settings
         self.data_settings: DataSettings = data_settings
@@ -75,11 +65,8 @@ class HyperTraining:
         self.optuna_settings: OptunaSettings = optuna_settings
         self.processes = None
-        # self.console = console or Console()
-        # set some extra settings to make the code more readable
         self._extra_optuna_settings()
-        self.stop_study = True
+        self.stop_study = False

     def setup_tb_writer(self, study_name=None, append=None):
         log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
@@ -229,7 +216,7 @@ class HyperTraining:
         self.optuna_settings._parallel = self.optuna_settings._n_threads > 1

     def define_model(self, trial: optuna.Trial, writer=None):
-        n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
+        n_hidden_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
         input_dim = trial.suggest_int_optional(
             "model_input_dim",
@@ -245,32 +232,41 @@ class HyperTraining:
         dtype = getattr(torch, dtype)
         afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
-        # T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
+        afunc = getattr(util.complexNN, afunc)
+        layer_func = trial.suggest_categorical_optional("model_layer_function", self.model_settings.model_layer_function)
+        layer_func = getattr(util.complexNN, layer_func)
+        layer_parametrizations = self.model_settings.model_layer_parametrizations
+        scale_layers = trial.suggest_categorical_optional("model_enable_scale_layers", self.model_settings.scale)

-        layers = []
-        last_dim = input_dim
-        n_nodes = last_dim
-        for i in range(n_layers):
+        hidden_dims = []
+        for i in range(n_hidden_layers):
             if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
-                hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
+                hidden_dims.append(trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override))
             else:
-                hidden_dim = trial.suggest_int_optional(
-                    f"model_hidden_dim_{i}",
-                    self.model_settings.n_hidden_nodes,
-                )
-            layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
-            last_dim = hidden_dim
-            layers.append(getattr(util.complexNN, afunc)())
-            n_nodes += last_dim
-        layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
-
-        model = nn.Sequential(*layers)
+                hidden_dims.append(trial.suggest_int_optional(
+                    f"model_hidden_dim_{i}",
+                    self.model_settings.n_hidden_nodes,
+                ))
+
+        model_kwargs = {
+            "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
+            "layer_function": layer_func,
+            "layer_parametrizations": layer_parametrizations,
+            "activation_function": afunc,
+            "dtype": dtype,
+            "dropout_prob": self.model_settings.dropout_prob,
+            "scale": scale_layers,
+        }
+        model = util.complexNN.regenerator(**model_kwargs)
+        n_nodes = sum(hidden_dims)

         if writer is not None:
             writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
-        n_params = sum(p.numel() for p in model.parameters())
+        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         trial.set_user_attr("model_n_params", n_params)
         trial.set_user_attr("model_n_nodes", n_nodes)
@@ -384,7 +380,8 @@ class HyperTraining:
         running_loss2 = 0.0
         running_loss = 0.0
         model.train()
-        for batch_idx, (x, y) in enumerate(train_loader):
+        loader_len = len(train_loader)
+        for batch_idx, (x, y, _) in enumerate(train_loader):
             if batch_idx >= self.optuna_settings._n_train_batches:
                 break
             model.zero_grad(set_to_none=True)
@@ -408,14 +405,14 @@ class HyperTraining:
                 writer.add_scalar(
                     "training loss",
                     running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
-                    epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
+                    epoch * min(loader_len, self.optuna_settings._n_train_batches) + batch_idx,
                 )
                 running_loss2 = 0.0
         # if enable_progress:
         #     progress.stop()
-        return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
+        return running_loss / min(loader_len, self.optuna_settings._n_train_batches)

     def eval_model(
         self,
@@ -446,9 +443,8 @@ class HyperTraining:
         model.eval()
         running_error = 0
-        running_error_2 = 0
         with torch.no_grad():
-            for batch_idx, (x, y) in enumerate(valid_loader):
+            for batch_idx, (x, y, _) in enumerate(valid_loader):
                 if batch_idx >= self.optuna_settings._n_valid_batches:
                     break
                 x, y = (
@@ -459,19 +455,6 @@ class HyperTraining:
                 error = util.complexNN.complex_mse_loss(y_pred, y)
                 error_value = error.item()
                 running_error += error_value
-                running_error_2 += error_value
-                # if enable_progress:
-                #     progress.update(task, advance=1, description=f"{error_value:.3e}")
-                if writer is not None:
-                    if batch_idx % self.pytorch_settings.write_every == 0:
-                        writer.add_scalar(
-                            "eval loss",
-                            running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
-                            epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
-                        )
-                        running_error_2 = 0.0

         running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
@@ -488,38 +471,73 @@ class HyperTraining:
                 ),
                 epoch + 1,
             )
+            writer.add_figure(
+                "eye diagram",
+                self.plot_model_response(
+                    trial,
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    show=False,
+                    mode="eye",
+                ),
+                epoch + 1,
+            )
+            writer.add_figure(
+                "powers",
+                self.plot_model_response(
+                    trial,
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    mode="powers",
+                    show=False,
+                ),
+                epoch + 1,
+            )
         # if enable_progress:
         #     progress.stop()
         return running_error

-    def run_model(self, model, loader):
+    def run_model(self, model, loader, trace_powers=False):
         model.eval()
-        xs = []
-        ys = []
-        y_preds = []
+        fiber_out = []
+        fiber_in = []
+        regen = []
+        timestamps = []
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
-            for x, y in loader:
+            for x, y, timestamp in loader:
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
-                y_pred = model(x).cpu()
+                if trace_powers:
+                    # model returns a (output, powers) tuple here, so unpack
+                    # before moving the output to the CPU
+                    y_pred, powers = model(x, trace_powers)
+                    y_pred = y_pred.cpu()
+                else:
+                    y_pred = model(x, trace_powers).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                 y = y.view(y.shape[0], -1, 2)
                 x = x.view(x.shape[0], -1, 2)
-                xs.append(x[:, 0, :].squeeze())
-                ys.append(y.squeeze())
-                y_preds.append(y_pred.squeeze())
-        xs = torch.vstack(xs).cpu()
-        ys = torch.vstack(ys).cpu()
-        y_preds = torch.vstack(y_preds).cpu()
-        return ys, xs, y_preds
+                # timestamp = timestamp.view(-1, 1)
+                fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
+                fiber_in.append(y.squeeze())
+                regen.append(y_pred.squeeze())
+                timestamps.append(timestamp.squeeze())
+        fiber_out = torch.vstack(fiber_out).cpu()
+        fiber_in = torch.vstack(fiber_in).cpu()
+        regen = torch.vstack(regen).cpu()
+        timestamps = torch.concat(timestamps).cpu()
+        if trace_powers:
+            return fiber_in, fiber_out, regen, timestamps, powers
+        return fiber_in, fiber_out, regen, timestamps

     def objective(self, trial: optuna.Trial, plot_before=False):
         if self.stop_study:
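Note on the new interface: every loader batch now yields a third element (the sample timestamps), and run_model returns them alongside the signals, plus the per-layer powers when requested. A minimal caller sketch (assuming a HyperTraining instance `ht` with a trained `model` and a `loader` from get_sliced_data; both names are hypothetical):

    # batches are (x, y, timestamp) tuples; run_model unpacks them internally
    fiber_in, fiber_out, regen, timestamps = ht.run_model(model, loader)

    # with trace_powers=True a fifth element carries the per-layer power trace
    fiber_in, fiber_out, regen, timestamps, powers = ht.run_model(model, loader, trace_powers=True)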
@@ -544,7 +562,32 @@ class HyperTraining:
                     model=model,
                     title_append=title_append,
                     subtitle=subtitle,
-                    show=plot_before,
+                    show=False,
+                ),
+                0,
+            )
+            writer.add_figure(
+                "eye diagram",
+                self.plot_model_response(
+                    trial,
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    mode="eye",
+                    show=False,
+                ),
+                0,
+            )
+            writer.add_figure(
+                "powers",
+                self.plot_model_response(
+                    trial,
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    mode="powers",
+                    show=False,
                 ),
                 0,
             )
@@ -609,7 +652,9 @@ class HyperTraining:
         return error

-    def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
+    def _plot_model_response_eye(
+        self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
+    ):
         if sps is None:
             raise ValueError("sps must be provided")
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -624,27 +669,84 @@ class HyperTraining:
         if not any(labels):
             labels = [f"signal {i + 1}" for i in range(len(signals))]

-        fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
+        x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
+        y_bins = np.zeros((2 * len(signals), 1000))
+        eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
+        # signals = [signal.cpu().numpy() for signal in signals]
+        for i in range(len(signals) * 2):
+            eye_signal = signals[i // 2][:, i % 2]  # x, y, x, y, ...
+            eye_signal = np.real(np.square(np.abs(eye_signal)))
+            data_min = np.min(eye_signal)
+            data_max = np.max(eye_signal)
+            y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
+            for j in range(len(timestamps)):
+                t = timestamps[j] / sps
+                val = eye_signal[j]
+                x = np.digitize(t % 2, x_bins) - 1
+                y = np.digitize(val, y_bins[i]) - 1
+                eye_data[i][y][x] += 1
+
+        cmap = LinearSegmentedColormap.from_list(
+            "eyemap",
+            [
+                (0, "white"),
+                (0.001, "dodgerblue"),
+                (0.1, "blue"),
+                (0.2, "cyan"),
+                (0.5, "lime"),
+                (0.8, "gold"),
+                (1, "red"),
+            ],
+        )
+
+        # ordering = np.argsort(timestamps)
+        # signals = [signal[ordering] for signal in signals]
+        # timestamps = timestamps[ordering]
+
+        fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
+        fig.set_figwidth(18)
         fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
-        xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
-        for j, (label, signal) in enumerate(zip(labels, signals)):
+        # xaxis = timestamps / sps
+        # xaxis = np.arange(2 * sps) / sps
+        for j, label in enumerate(labels):
+            x = eye_data[2 * j]
+            y = eye_data[2 * j + 1]
+            # x, y = signal.T
             # signal = signal.cpu().numpy()
-            for i in range(len(signal) // sps - 1):
-                x, y = signal[i * sps : (i + 2) * sps].T
-                axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
-                axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
-            axs[0, j].set_title(label + " x")
-            axs[1, j].set_title(label + " y")
-            axs[0, j].set_xlabel("Symbol")
-            axs[1, j].set_xlabel("Symbol")
-            axs[0, j].set_ylabel("normalized power")
-            axs[1, j].set_ylabel("normalized power")
+            # for i in range(len(signal) // sps - 1):
+            #     x, y = signal[i * sps : (i + 2) * sps].T
+            #     axs[0 + 2 * j].scatter((timestamps / sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+            #     axs[1 + 2 * j].scatter((timestamps / sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+            axs[0 + 2 * j].imshow(
+                x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
+            )
+            axs[1 + 2 * j].imshow(
+                y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
+            )
+            axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+            axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+            ymin = np.min(y_bins[:, 0])
+            ymax = np.max(y_bins[:, -1])
+            ydiff = ymax - ymin
+            axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
+            axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
+            axs[0 + 2 * j].set_title(label + " x")
+            axs[1 + 2 * j].set_title(label + " y")
+            axs[0 + 2 * j].set_xlabel("Symbol")
+            axs[1 + 2 * j].set_xlabel("Symbol")
+            axs[0 + 2 * j].set_box_aspect(1)
+            axs[1 + 2 * j].set_box_aspect(1)
+        axs[0].set_ylabel("normalized power")
+        fig.tight_layout()
+        # axs[1 + 2 * len(labels) - 1].set_ylabel("normalized power")
         if show:
             plt.show()
         return fig

-    def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
+    def _plot_model_response_head(
+        self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
+    ):
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
             labels = [labels]
         else:
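The per-sample Python loop above is the straightforward way to fill the 2D histogram; the same binning can also be expressed in one vectorized call. A minimal sketch of the equivalent operation (synthetic data, all names hypothetical), which may be useful if the loop becomes a bottleneck:

    import numpy as np

    sps = 128                                  # samples per symbol, as in the dataset configs
    rng = np.random.default_rng(0)
    timestamps = np.arange(10_000)             # sample indices
    power = rng.random(10_000)                 # |E|^2 of one polarization (synthetic)

    # fold time onto a two-symbol window, exactly like the t % 2 above
    t = (timestamps / sps) % 2
    eye, y_edges, x_edges = np.histogram2d(
        power, t, bins=(1000, 2 * sps), range=((power.min(), power.max()), (0, 2))
    )
    # eye has shape (1000, 2 * sps), matching eye_data[i] up to the
    # open/closed bin-edge convention of np.digitize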
@@ -657,19 +759,31 @@ class HyperTraining:
         if not any(labels):
             labels = [f"signal {i + 1}" for i in range(len(signals))]

+        ordering = np.argsort(timestamps)
+        signals = [signal[ordering] for signal in signals]
+        timestamps = timestamps[ordering]
+
         fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
-        fig.set_size_inches(18, 6)
+        fig.set_figwidth(18)
+        fig.set_figheight(4)
         fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
         for i, ax in enumerate(axs):
+            ax: plt.Axes
             for signal, label in zip(signals, labels):
                 if sps is not None:
-                    xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
+                    xaxis = timestamps / sps
                 else:
-                    xaxis = np.arange(len(signal))
+                    xaxis = timestamps
                 ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
             ax.set_xlabel("Sample" if sps is None else "Symbol")
             ax.set_ylabel("normalized power")
+            ax.minorticks_on()
+            ax.tick_params(axis="y", which="minor", left=False, right=False)
+            ax.grid(which="major", axis="x")
+            ax.grid(which="minor", axis="x", linestyle=":")
+            ax.grid(which="major", axis="y")
             ax.legend(loc="upper right")
+        fig.tight_layout()
         if show:
             plt.show()
         return fig
@@ -680,22 +794,52 @@ class HyperTraining:
         model=None,
         title_append="",
         subtitle="",
-        mode: Literal["eye", "head"] = "head",
-        show=True,
+        mode: Literal["eye", "head", "powers"] = "head",
+        show=False,
     ):
+        if mode == "powers":
+            input_data = torch.ones(
+                1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
+            ).to(self.pytorch_settings.device)
+            model = model.to(self.pytorch_settings.device)
+            model.eval()
+            with torch.no_grad():
+                _, powers = model(input_data, trace_powers=True)
+            powers = [power.item() for power in powers]
+            layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
+            # remove dropout layers
+            mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
+            layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
+            powers = [power for power, m in zip(powers, mask) if m]
+            fig = self._plot_model_response_powers(
+                powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
+            )
+            return fig
+
         data_settings_backup = copy.deepcopy(self.data_settings)
         pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
-        self.data_settings.drop_first = 100*128
+        self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
         self.pytorch_settings.batchsize = (
             self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
         )
-        plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
+        config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
+        fiber_length = int(float(str(config_path).split('-')[-7])/1000)
+        plot_loader, _ = self.get_sliced_data(
+            trial,
+            override={
+                "num_symbols": self.pytorch_settings.batchsize,
+                "config_path": config_path,
+            }
+        )
         self.data_settings = data_settings_backup
         self.pytorch_settings = pytorch_settings_backup
-        fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
+        fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
         fiber_in = fiber_in.view(-1, 2)
         fiber_out = fiber_out.view(-1, 2)
         regen = regen.view(-1, 2)
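The `-7` index in the fiber-length parse counts fields from the end of the .ini name, so it is insensitive to how many leading path components the glob resolved. A worked check against the naming scheme of the configs in this commit (the concrete filename here is hypothetical):

    # <date>-<time>-<sps>-<nsym>-<length_m>-0-0-17-0-PAM4-0.ini
    config_path = "data/20241204-132605-128-16384-100000-0-0-17-0-PAM4-0.ini"
    fields = str(config_path).split("-")
    # fields[-7] is the fiber length in metres: "100000"
    fiber_length = int(float(fields[-7]) / 1000)
    assert fiber_length == 100  # km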
@@ -703,6 +847,7 @@ class HyperTraining:
fiber_in = fiber_in.numpy() fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy() fiber_out = fiber_out.numpy()
regen = regen.numpy() regen = regen.numpy()
timestamps = timestamps.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987 # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463 # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
@@ -713,9 +858,10 @@ class HyperTraining:
                 fiber_in,
                 fiber_out,
                 regen,
+                timestamps=timestamps,
                 labels=("fiber in", "fiber out", "regen"),
                 sps=plot_loader.dataset.samples_per_symbol,
-                title_append=title_append,
+                title_append=title_append + f" ({fiber_length} km)",
                 subtitle=subtitle,
                 show=show,
             )
@@ -725,9 +871,10 @@ class HyperTraining:
                 fiber_in,
                 fiber_out,
                 regen,
+                timestamps=timestamps,
                 labels=("fiber in", "fiber out", "regen"),
                 sps=plot_loader.dataset.samples_per_symbol,
-                title_append=title_append,
+                title_append=title_append + f" ({fiber_length} km)",
                 subtitle=subtitle,
                 show=show,
             )
@@ -739,7 +886,7 @@ class HyperTraining:
     @staticmethod
     def build_title(trial: optuna.trial.Trial):
-        title_append = f"for trial {trial.number}"
+        title_append = f"at epoch {trial.user_attrs.get('epoch', -1)} for trial {trial.number}"
         model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
         input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
         model_dims = [

View File

@@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 # import torch.nn.functional as F  # mse_loss doesn't support complex numbers
 import torch.optim as optim
@@ -47,88 +47,6 @@ from .settings import (
 )

-class regenerator(nn.Module):
-    def __init__(
-        self,
-        *dims,
-        layer_function=util.complexNN.ONN,
-        layer_kwargs: dict | None = None,
-        layer_parametrizations: list[dict] = None,
-        # [
-        #     {
-        #         "tensor_name": "weight",
-        #         "parametrization": util.complexNN.Unitary,
-        #     },
-        #     {
-        #         "tensor_name": "scale",
-        #         "parametrization": util.complexNN.Clamp,
-        #     },
-        # ],
-        activation_function=util.complexNN.Pow,
-        dtype=torch.float64,
-        dropout_prob=0.01,
-        scale=False,
-        **kwargs,
-    ):
-        super(regenerator, self).__init__()
-        if len(dims) == 0:
-            try:
-                dims = kwargs["dims"]
-            except KeyError:
-                raise ValueError("dims must be provided")
-        self._n_hidden_layers = len(dims) - 2
-        self._layers = nn.Sequential()
-        if layer_kwargs is None:
-            layer_kwargs = {}
-        # self.powers = []
-        for i in range(self._n_hidden_layers + 1):
-            if scale:
-                self._layers.append(util.complexNN.Scale(dims[i]))
-            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
-            if i < self._n_hidden_layers:
-                if dropout_prob is not None:
-                    self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
-                self._layers.append(activation_function(bias=True, size=dims[i + 1]))
-        self._layers.append(util.complexNN.Scale(dims[-1]))
-
-        # add parametrizations
-        if layer_parametrizations is not None:
-            for layer in self._layers:
-                for layer_parametrization in layer_parametrizations:
-                    tensor_name = layer_parametrization.get("tensor_name", None)
-                    parametrization = layer_parametrization.get("parametrization", None)
-                    param_kwargs = layer_parametrization.get("kwargs", {})
-                    if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
-                        parametrization(layer, tensor_name, **param_kwargs)
-
-    # def __call__(self, input_x, **kwargs):
-    #     return self.forward(input_x, **kwargs)
-
-    def forward(self, input_x, trace_powers=False):
-        x = input_x
-        if trace_powers:
-            powers = [x.abs().square().sum()]
-        # check if tracing
-        if torch.jit.is_tracing():
-            for layer in self._layers:
-                x = layer(x)
-                if trace_powers:
-                    powers.append(x.abs().square().sum())
-        else:
-            # with torch.nn.utils.parametrize.cached():
-            for layer in self._layers:
-                x = layer(x)
-                if trace_powers:
-                    powers.append(x.abs().square().sum())
-        if trace_powers:
-            return x, powers
-        return x

 def traverse_dict_update(target, source):
     for k, v in source.items():
         if isinstance(v, dict):
@@ -164,7 +82,7 @@ class Trainer:
             ModelSettings,
             OptimizerSettings,
             PytorchSettings,
-            regenerator,
+            util.complexNN.regenerator,
             torch.nn.utils.parametrizations.orthogonal,
         ])
         if self.resume:
@@ -264,7 +182,7 @@ class Trainer:
         dtype = self.model_kwargs["dtype"]
         # dims = self.model_kwargs.pop("dims")
-        self.model = regenerator(**self.model_kwargs)
+        self.model = util.complexNN.regenerator(**self.model_kwargs)
         if self.writer is not None:
             self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
@@ -364,7 +282,7 @@ class Trainer:
             task = progress.add_task("-.---e--", total=len(train_loader))
             progress.start()
-        # running_loss2 = 0.0
+        running_loss2 = 0.0
         running_loss = 0.0
         self.model.train()
         loader_len = len(train_loader)
@@ -379,23 +297,24 @@ class Trainer:
             loss_value = loss.item()
             loss.backward()
             optimizer.step()
-            # running_loss2 += loss_value
+            running_loss2 += loss_value
             running_loss += loss_value
             if enable_progress:
-                progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}")
+                progress.update(task, advance=1, description=f"{loss_value:.3e}")
             if batch_idx % self.pytorch_settings.write_every == 0:
                 self.writer.add_scalar(
                     "training loss",
-                    running_loss / (batch_idx + 1),
-                    epoch * loader_len + batch_idx,
+                    running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
+                    epoch + batch_idx/loader_len,
                 )
+                running_loss2 = 0.0

         if enable_progress:
             progress.stop()
-        return running_loss / (batch_idx + 1)
+        return running_loss / len(train_loader)
def eval_model(self, valid_loader, epoch, enable_progress=True): def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress: if enable_progress:
@@ -418,7 +337,7 @@ class Trainer:
self.model.eval() self.model.eval()
running_error = 0 running_error = 0
with torch.no_grad(): with torch.no_grad():
for batch_idx, (x, y, _) in enumerate(valid_loader): for _, (x, y, _) in enumerate(valid_loader):
x, y = ( x, y = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
@@ -429,9 +348,9 @@ class Trainer:
                 running_error += error_value
                 if enable_progress:
-                    progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}")
+                    progress.update(task, advance=1, description=f"{error_value:.3e}")

-        running_error /= (batch_idx+1)
+        running_error = running_error/len(valid_loader)

         self.writer.add_scalar(
             "eval loss",
@@ -858,7 +777,7 @@ class Trainer:
         self.pytorch_settings.batchsize = (
             self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
         )
-        config_path = random.choice(self.data_settings.config_path)
+        config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
         fiber_length = int(float(str(config_path).split('-')[-7])/1000)
         plot_loader, _ = self.get_sliced_data(
             override={

View File

@@ -1,3 +1,4 @@
+from pathlib import Path
 import matplotlib
 import numpy as np
 import torch
@@ -22,8 +23,8 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    # config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
-    config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
+    config_path="data/20241204-13*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
     dtype="complex64",
     # symbols = (9, 20),  # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
     symbols=13,  # study: single_core_regen_20241123_011232
@@ -52,8 +53,8 @@ model_settings = ModelSettings(
     output_dim=2,
     n_hidden_layers=4,
     overrides={
-        "n_hidden_nodes_0": 4,
-        "n_hidden_nodes_1": 4,
+        "n_hidden_nodes_0": 8,
+        "n_hidden_nodes_1": 8,
         "n_hidden_nodes_2": 4,
         "n_hidden_nodes_3": 4,
     },
@@ -61,7 +62,7 @@ model_settings = ModelSettings(
     dropout_prob=0.01,
     model_layer_function="ONNRect",
     model_layer_kwargs={"square": True},
-    scale=True,
+    scale=False,
     model_layer_parametrizations=[
         {
             "tensor_name": "weight",
@@ -113,7 +114,7 @@ model_settings = ModelSettings(
 optimizer_settings = OptimizerSettings(
     optimizer="AdamW",
     optimizer_kwargs={
-        "lr": 0.05,
+        "lr": 0.01,
         "amsgrad": True,
         # "weight_decay": 1e-7,
     },
@@ -142,8 +143,9 @@ def save_dict_to_file(dictionary, filename):
         json.dump(dictionary, f, indent=4)

-def sweep_lengths(*lengths, model=None):
+def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
     assert model is not None, "Model must be provided."
+    assert data_glob is not None, "Data glob must be provided."
     model = model

     fiber_ins = {}
@@ -151,19 +153,31 @@
     regens = {}
     timestampss = {}

-    for length in lengths:
-        trainer = Trainer(
-            checkpoint_path=model,
-            settings_override={
-                "data_settings": {
-                    "config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
-                    "train_split": 1,
-                    "shuffle": True,
-                }
-            },
-        )
-        trainer.define_model()
-        loader, _ = trainer.get_sliced_data()
+    trainer = Trainer(
+        checkpoint_path=model,
+    )
+    trainer.define_model()
+
+    for length in lengths:
+        data_glob_length = data_glob.replace("{length}", str(length))
+        files = list(Path.cwd().glob(data_glob_length))
+        if len(files) == 0:
+            continue
+        if strategy == "newest":
+            sorted_kwargs = {
+                'key': lambda x: x.stat().st_mtime,
+                'reverse': True,
+            }
+        elif strategy == "oldest":
+            sorted_kwargs = {
+                'key': lambda x: x.stat().st_mtime,
+                'reverse': False,
+            }
+        else:
+            raise ValueError(f"Unknown strategy {strategy}.")
+        file = sorted(files, **sorted_kwargs)[0]
+        loader, _ = trainer.get_sliced_data(override={"config_path": file})
         fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
         fiber_ins[length] = fiber_in
@@ -171,17 +185,23 @@
         regens[length] = regen
         timestampss[length] = timestamps

-    data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
-    channel_names = ["" for _ in range(2 * len(lengths))]
-    for li, length in enumerate(lengths):
-        data[2 * li, 0, :] = timestampss[length] / 128
-        data[2 * li, 1, :] = regens[length][:, 0].abs().square()
-        data[2 * li + 1, 0, :] = timestampss[length] / 128
-        data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
-        channel_names[2 * li] = f"regen x {length}"
-        channel_names[2 * li + 1] = f"regen y {length}"
+    data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
+    channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
+
+    data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
+    data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
+    channel_names[1] = "fiber in x"
+
+    for li, length in enumerate(timestampss.keys()):
+        data[2 + 2 * li, 0, :] = timestampss[length] / 128
+        data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
+        data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
+        data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
+        channel_names[2 + 2 * li] = f"fiber out x {length}"
+        channel_names[2 + 2 * li + 1] = f"regen x {length}"

     # get current backend
     backend = matplotlib.get_backend()
@@ -189,28 +209,30 @@
     matplotlib.use("TkCairo")
     eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)

-    print_attrs = ("channel", "success", "min_area")
+    print_attrs = ("channel_name", "success", "min_area")
     with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
         for result in eye.eye_stats:
             print_dict = {attr: result[attr] for attr in print_attrs}
             rprint(print_dict)
             rprint()
-    eye.plot()
+    eye.plot(all_stats=False)
     matplotlib.use(backend)

 if __name__ == "__main__":
-    # sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
-
-    trainer = Trainer(
-        global_settings=global_settings,
-        data_settings=data_settings,
-        pytorch_settings=pytorch_settings,
-        model_settings=model_settings,
-        optimizer_settings=optimizer_settings,
-        # checkpoint_path=".models/best_20241202_143149.tar",
-        # 20241202_143149
-    )
-    trainer.train()
+    lengths = range(90000, 100000 + 10000, 10000)
+    # lengths = [100000]
+    sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
+
+    # trainer = Trainer(
+    #     global_settings=global_settings,
+    #     data_settings=data_settings,
+    #     pytorch_settings=pytorch_settings,
+    #     model_settings=model_settings,
+    #     optimizer_settings=optimizer_settings,
+    #     # checkpoint_path=".models/best_20241202_143149.tar",
+    #     # # 20241202_143149
+    # )
+    # trainer.train()
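The strategy branch in sweep_lengths boils down to a sort key plus a reverse flag; a compact, equivalent sketch of the selection logic (the helper name is hypothetical):

    from pathlib import Path

    def pick_file(data_glob: str, length: int, strategy: str = "newest") -> Path | None:
        # substitute the literal "{length}" placeholder, as sweep_lengths does
        files = list(Path.cwd().glob(data_glob.replace("{length}", str(length))))
        if not files:
            return None
        if strategy not in ("newest", "oldest"):
            raise ValueError(f"Unknown strategy {strategy}.")
        # "newest" = largest modification time first
        return sorted(files, key=lambda p: p.stat().st_mtime, reverse=(strategy == "newest"))[0]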

View File

@@ -39,7 +39,7 @@ import numpy as np

 if __name__ == "__main__":
-    dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
+    dataset = FiberRegenerationDataset("data/202412*-128-16384-50000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
     loader = DataLoader(dataset, batch_size=10, shuffle=True)

View File

@@ -569,6 +569,78 @@ class ZReLU(nn.Module):
             return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
         else:
             return torch.relu(x)

+class regenerator(nn.Module):
+    def __init__(
+        self,
+        *dims,
+        layer_function=ONN,
+        layer_kwargs: dict | None = None,
+        layer_parametrizations: list[dict] = None,
+        activation_function=Pow,
+        dtype=torch.float64,
+        dropout_prob=0.01,
+        scale=False,
+        **kwargs,
+    ):
+        super(regenerator, self).__init__()
+        if len(dims) == 0:
+            try:
+                dims = kwargs["dims"]
+            except KeyError:
+                raise ValueError("dims must be provided")
+        self._n_hidden_layers = len(dims) - 2
+        self._layers = nn.Sequential()
+        if layer_kwargs is None:
+            layer_kwargs = {}
+        # self.powers = []
+        for i in range(self._n_hidden_layers + 1):
+            if scale:
+                self._layers.append(Scale(dims[i]))
+            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
+            if i < self._n_hidden_layers:
+                if dropout_prob is not None:
+                    self._layers.append(DropoutComplex(p=dropout_prob))
+                self._layers.append(activation_function(bias=True, size=dims[i + 1]))
+        self._layers.append(Scale(dims[-1]))
+
+        # add parametrizations
+        if layer_parametrizations is not None:
+            for layer in self._layers:
+                for layer_parametrization in layer_parametrizations:
+                    tensor_name = layer_parametrization.get("tensor_name", None)
+                    parametrization = layer_parametrization.get("parametrization", None)
+                    param_kwargs = layer_parametrization.get("kwargs", {})
+                    if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
+                        parametrization(layer, tensor_name, **param_kwargs)
+
+    # def __call__(self, input_x, **kwargs):
+    #     return self.forward(input_x, **kwargs)
+
+    def forward(self, input_x, trace_powers=False):
+        x = input_x
+        if trace_powers:
+            powers = [x.abs().square().sum()]
+        # check if tracing
+        if torch.jit.is_tracing():
+            for layer in self._layers:
+                x = layer(x)
+                if trace_powers:
+                    powers.append(x.abs().square().sum())
+        else:
+            # with torch.nn.utils.parametrize.cached():
+            for layer in self._layers:
+                x = layer(x)
+                if trace_powers:
+                    powers.append(x.abs().square().sum())
+        if trace_powers:
+            return x, powers
+        return x

 __all__ = [
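Construction sketch for the relocated class (a minimal example assuming the ONNRect and Pow callables that util.complexNN exports elsewhere in this commit; the dims tuple is hypothetical):

    import torch
    import util.complexNN as cnn

    # dims = (input, *hidden, output), mirroring the tuple built in define_model
    model = cnn.regenerator(
        dims=(26, 8, 8, 4, 4, 2),
        layer_function=cnn.ONNRect,
        layer_kwargs={"square": True},
        activation_function=cnn.Pow,
        dtype=torch.complex64,
        dropout_prob=0.01,
        scale=False,
    )

    x = torch.zeros(1, 26, dtype=torch.complex64)
    y = model(x)                              # plain forward pass
    y, powers = model(x, trace_powers=True)   # also returns per-layer sum(|x|^2)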

View File

@@ -3,6 +3,7 @@ from matplotlib.colors import LinearSegmentedColormap
 import numpy as np
 from scipy.cluster.vq import kmeans2
 import warnings
+import multiprocessing

 from rich.traceback import install
 from rich import pretty
@@ -67,7 +68,7 @@ def generate_wavelet(sps, oversample=3):

 class eye_diagram:
-    def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4):
+    def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
         # data has shape [channels, 2, samples]
         # each sample has a timestamp and a value
         if data.ndim == 2:
@@ -79,28 +80,38 @@ class eye_diagram:
         self.eye_stats = [{"success": False} for _ in range(self.channels)]
         self.horizontal_bins = horizontal_bins
         self.vertical_bins = vertical_bins
+        self.multi_threaded = multithreaded
         self.eye_built = False
-        self.analyse(self.n_levels)
+        self.analyse()

     def generate_eye_data(self):
         self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
         self.y_bins = np.zeros((self.channels, self.vertical_bins))
         self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
-        for i in range(self.channels):
-            data_min = np.min(self.raw_data[i, 1, :])
-            data_max = np.max(self.raw_data[i, 1, :])
-            self.y_bins[i] = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
-
-            t_vals = self.raw_data[i, 0, :] % 2
-            val_vals = self.raw_data[i, 1, :]
-
-            x_indices = np.digitize(t_vals, self.x_bins) - 1
-            y_indices = np.digitize(val_vals, self.y_bins[i]) - 1
-            np.add.at(self.eye_data[i], (y_indices, x_indices), 1)
+        datas = [self.raw_data[i] for i in range(self.channels)]
+        if self.multi_threaded:
+            with multiprocessing.Pool() as pool:
+                results = pool.map(self.generate_eye_data_single, datas)
+            for i, result in enumerate(results):
+                self.eye_data[i], self.y_bins[i] = result
+        else:
+            for i, data in enumerate(datas):
+                self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
         self.eye_built = True

-    def plot(self, title="Eye Diagram", stats=True, show=True):
+    def generate_eye_data_single(self, data):
+        eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
+        data_min = np.min(data[1, :])
+        data_max = np.max(data[1, :])
+        y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
+        t_vals = data[0, :] % 2
+        val_vals = data[1, :]
+        x_indices = np.digitize(t_vals, self.x_bins) - 1
+        y_indices = np.digitize(val_vals, y_bins) - 1
+        np.add.at(eye_data, (y_indices, x_indices), 1)
+        return eye_data, y_bins
+
+    def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
         if not self.eye_built:
             self.generate_eye_data()
         cmap = LinearSegmentedColormap.from_list(
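One caveat on pool.map(self.generate_eye_data_single, ...): a bound method is pickled as the instance plus the method name, so the whole eye_diagram object (including raw_data) is shipped to every worker, and each worker mutates only its own copy, which is why the results are written back in the parent. A stripped-down sketch of the pattern (the class here is hypothetical):

    import multiprocessing
    import numpy as np

    class Binner:
        def __init__(self, bins):
            self.bins = bins  # every attribute must be picklable for the workers

        def bin_one(self, data):
            return np.digitize(data, self.bins) - 1

        def bin_all(self, datas):
            # each worker gets a pickled copy of self and one item of datas
            with multiprocessing.Pool() as pool:
                return pool.map(self.bin_one, datas)

    if __name__ == "__main__":  # guard required on spawn-based platforms
        b = Binner(np.linspace(0.0, 1.0, 10))
        print(b.bin_all([np.random.rand(5) for _ in range(3)]))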
@@ -118,8 +129,10 @@ class eye_diagram:
         ax = np.atleast_1d(ax).transpose().flatten()
         for i in range(self.channels):
             ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
-            ax[i].set_xlabel("Symbol")
-            ax[i].set_ylabel("Amplitude")
+            if (i + 1) % rows == 0:
+                ax[i].set_xlabel("Symbol")
+            if i < rows:
+                ax[i].set_ylabel("Amplitude")
             ax[i].grid()
             ax[i].imshow(
                 self.eye_data[i],
@@ -134,67 +147,6 @@ class eye_diagram:
             yspan = ymax - ymin
             ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
             if stats and self.eye_stats[i]["success"]:
-                ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
-                ax[i].set_yticks(self.eye_stats[i]["levels"])
-                # add arrows for amplitudes
-                for j in range(len(self.eye_stats[i]["amplitudes"])):
-                    ax[i].annotate(
-                        "",
-                        xy=(0.05, self.eye_stats[i]["levels"][j]),
-                        xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
-                        arrowprops=dict(arrowstyle="<->", facecolor="black"),
-                    )
-                    ax[i].annotate(
-                        f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
-                        xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
-                    )
-                # add arrows for eye heights
-                for j in range(len(self.eye_stats[i]["heights"])):
-                    try:
-                        bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
-                        top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
-                        ax[i].annotate(
-                            "",
-                            xy=(self.eye_stats[i]["time_midpoint"], bot),
-                            xytext=(self.eye_stats[i]["time_midpoint"], top),
-                            arrowprops=dict(arrowstyle="<->", facecolor="black"),
-                        )
-                        ax[i].annotate(
-                            f"{self.eye_stats[i]['heights'][j]:.2e}",
-                            xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
-                        )
-                    except (ValueError, IndexError):
-                        pass
-                # add arrows for eye widths
-                for j in range(len(self.eye_stats[i]["widths"])):
-                    try:
-                        left = np.max(self.eye_stats[i]["time_clusters"][j][0])
-                        right = np.min(self.eye_stats[i]["time_clusters"][j][1])
-                        vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
-                        ax[i].annotate(
-                            "",
-                            xy=(left, vertical),
-                            xytext=(right, vertical),
-                            arrowprops=dict(arrowstyle="<->", facecolor="black"),
-                        )
-                        ax[i].annotate(
-                            f"{self.eye_stats[i]['widths'][j]:.2e}",
-                            xy=((left + right) / 2 - 0.15, vertical + 0.01),
-                        )
-                    except (ValueError, IndexError):
-                        pass
-                # add area
-                for j in range(len(self.eye_stats[i]["areas"])):
-                    horizontal = self.eye_stats[i]["time_midpoint"]
-                    vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
-                    ax[i].annotate(
-                        f"{self.eye_stats[i]['areas'][j]:.2e}",
-                        xy=(horizontal + 0.035, vertical - 0.07),
-                    )
                 # add min_area above the plot
                 ax[i].annotate(
                     f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
@@ -202,62 +154,142 @@ class eye_diagram:
                     # xycoords="axes fraction",
                     ha="left",
                     va="center",
+                    bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                 )
+                if all_stats:
+                    ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
+                    ax[i].set_yticks(self.eye_stats[i]["levels"])
+                    # add arrows for amplitudes
+                    for j in range(len(self.eye_stats[i]["amplitudes"])):
+                        ax[i].annotate(
+                            "",
+                            xy=(0.05, self.eye_stats[i]["levels"][j]),
+                            xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
+                            arrowprops=dict(arrowstyle="<->", facecolor="black"),
+                        )
+                        ax[i].annotate(
+                            f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
+                            xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
+                            bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                        )
+                    # add arrows for eye heights
+                    for j in range(len(self.eye_stats[i]["heights"])):
+                        try:
+                            bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
+                            top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
+                            ax[i].annotate(
+                                "",
+                                xy=(self.eye_stats[i]["time_midpoint"], bot),
+                                xytext=(self.eye_stats[i]["time_midpoint"], top),
+                                arrowprops=dict(arrowstyle="<->", facecolor="black"),
+                            )
+                            ax[i].annotate(
+                                f"{self.eye_stats[i]['heights'][j]:.2e}",
+                                xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
+                                bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                            )
+                        except (ValueError, IndexError):
+                            pass
+                    # add arrows for eye widths
+                    for j in range(len(self.eye_stats[i]["widths"])):
+                        try:
+                            left = np.max(self.eye_stats[i]["time_clusters"][j][0])
+                            right = np.min(self.eye_stats[i]["time_clusters"][j][1])
+                            vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+                            ax[i].annotate(
+                                "",
+                                xy=(left, vertical),
+                                xytext=(right, vertical),
+                                arrowprops=dict(arrowstyle="<->", facecolor="black"),
+                            )
+                            ax[i].annotate(
+                                f"{self.eye_stats[i]['widths'][j]:.2e}",
+                                xy=((left + right) / 2 - 0.15, vertical + 0.01),
+                                bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                            )
+                        except (ValueError, IndexError):
+                            pass
+                    # add area
+                    for j in range(len(self.eye_stats[i]["areas"])):
+                        horizontal = self.eye_stats[i]["time_midpoint"]
+                        vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+                        ax[i].annotate(
+                            f"{self.eye_stats[i]['areas'][j]:.2e}",
+                            xy=(horizontal + 0.035, vertical - 0.07),
+                            bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                        )

         fig.tight_layout()
         if show:
             plt.show()
         return fig

-    def analyse(self, n_levels=4):
+    def analyse_single(self, data, index):
         warnings.filterwarnings("error")
-        for i in range(self.channels):
-            self.eye_stats[i]["channel"] = str(i+1) if self.channel_names is None else self.channel_names[i]
-            try:
-                approx_levels = eye_diagram.approximate_levels(self.raw_data[i], n_levels)
-                time_bounds = eye_diagram.calculate_time_bounds(self.raw_data[i], approx_levels)
-                self.eye_stats[i]["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
-                self.eye_stats[i]["levels"], self.eye_stats[i]["amplitude_clusters"] = eye_diagram.calculate_levels(
-                    self.raw_data[i], approx_levels, time_bounds
-                )
-                self.eye_stats[i]["amplitudes"] = np.diff(self.eye_stats[i]["levels"])
-                self.eye_stats[i]["heights"] = eye_diagram.calculate_eye_heights(
-                    self.eye_stats[i]["amplitude_clusters"]
-                )
-                self.eye_stats[i]["widths"], self.eye_stats[i]["time_clusters"] = eye_diagram.calculate_eye_widths(
-                    self.raw_data[i], self.eye_stats[i]["levels"]
-                )
-                # # check if time clusters are valid (upper bound > time_midpoint > lower bound)
-                # # if not: raise ValueError
-                # for j in range(len(self.eye_stats[i]['time_clusters'])):
-                #     if not (np.max(self.eye_stats[i]['time_clusters'][j][0]) < self.eye_stats[i]["time_midpoint"] < np.min(self.eye_stats[i]['time_clusters'][j][1])):
-                #         raise ValueError
-                self.eye_stats[i]["areas"] = self.eye_stats[i]["heights"] * self.eye_stats[i]["widths"]
-                self.eye_stats[i]["mean_area"] = np.mean(self.eye_stats[i]["areas"])
-                self.eye_stats[i]["min_area"] = np.min(self.eye_stats[i]["areas"])
-                self.eye_stats[i]["success"] = True
-            except (RuntimeWarning, UserWarning, ValueError):
-                self.eye_stats[i]["success"] = False
-                self.eye_stats[i]["time_midpoint"] = 0
-                self.eye_stats[i]["levels"] = np.zeros(n_levels)
-                self.eye_stats[i]["amplitude_clusters"] = []
-                self.eye_stats[i]["amplitudes"] = np.zeros(n_levels - 1)
-                self.eye_stats[i]["heights"] = np.zeros(n_levels - 1)
-                self.eye_stats[i]["widths"] = np.zeros(n_levels - 1)
-                self.eye_stats[i]["areas"] = np.zeros(n_levels - 1)
-                self.eye_stats[i]["mean_area"] = 0
-                self.eye_stats[i]["min_area"] = 0
+        eye_stats = {}
+        eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
+        try:
+            approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
+            time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
+            eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
+            eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
+                data, approx_levels, time_bounds
+            )
+            eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
+            eye_stats["heights"] = eye_diagram.calculate_eye_heights(
+                eye_stats["amplitude_clusters"]
+            )
+            eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
+                data, eye_stats["levels"]
+            )
+            # # check if time clusters are valid (upper bound > time_midpoint > lower bound)
+            # # if not: raise ValueError
+            # for j in range(len(eye_stats['time_clusters'])):
+            #     if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
+            #         raise ValueError
+            eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
+            eye_stats["mean_area"] = np.mean(eye_stats["areas"])
+            eye_stats["min_area"] = np.min(eye_stats["areas"])
+            eye_stats["success"] = True
+        except (RuntimeWarning, UserWarning, ValueError):
+            eye_stats["success"] = False
+            eye_stats["time_midpoint"] = 0
+            eye_stats["levels"] = np.zeros(self.n_levels)
+            eye_stats["amplitude_clusters"] = []
+            eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
+            eye_stats["heights"] = np.zeros(self.n_levels - 1)
+            eye_stats["widths"] = np.zeros(self.n_levels - 1)
+            eye_stats["areas"] = np.zeros(self.n_levels - 1)
+            eye_stats["mean_area"] = 0
+            eye_stats["min_area"] = 0
         warnings.resetwarnings()
+        return eye_stats
+
+    def analyse(self):
+        self.eye_stats = []
+        if self.multi_threaded:
+            with multiprocessing.Pool() as pool:
+                results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
+            for i, result in enumerate(results):
+                self.eye_stats.append(result)
+        else:
+            for i in range(self.channels):
+                self.eye_stats.append(self.analyse_single(self.raw_data[i], i))

     @staticmethod
     def approximate_levels(data, levels):
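Usage sketch for the reworked class (a minimal example with synthetic data; a random signal has no PAM4 structure, so its channel will simply report success=False):

    import numpy as np
    import util.eye_diagram

    # data has shape [channels, 2, samples]: per channel one row of
    # timestamps (in symbols) and one row of values
    t = np.arange(4096) / 128                  # 128 samples per symbol
    vals = np.random.rand(4096)                # synthetic; real data would be PAM4 powers
    data = np.stack([np.stack([t, vals])])     # a single channel

    eye = util.eye_diagram.eye_diagram(data, channel_names=["demo"], multithreaded=False)
    for result in eye.eye_stats:
        print(result["channel_name"], result["success"], result["min_area"])
    eye.plot(all_stats=False)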

src/visualization/viz.ipynb (new file, 82160 lines added): diff suppressed because it is too large.