add regenerator class and update dataset configurations for model training
This commit is contained in:
@@ -1,7 +1,16 @@
|
|||||||
import copy
|
import copy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import random
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
import matplotlib
|
||||||
|
from matplotlib.colors import LinearSegmentedColormap
|
||||||
|
import torch.nn.utils.parametrize
|
||||||
|
|
||||||
|
try:
|
||||||
|
matplotlib.use("cairo")
|
||||||
|
except ImportError:
|
||||||
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -11,7 +20,7 @@ import optuna
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
# import torch.nn as nn
|
||||||
|
|
||||||
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
@@ -19,27 +28,9 @@ import torch.utils.data
|
|||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
# from rich.progress import (
|
|
||||||
# Progress,
|
|
||||||
# TextColumn,
|
|
||||||
# BarColumn,
|
|
||||||
# TaskProgressColumn,
|
|
||||||
# TimeRemainingColumn,
|
|
||||||
# MofNCompleteColumn,
|
|
||||||
# TimeElapsedColumn,
|
|
||||||
# )
|
|
||||||
# from rich.console import Console
|
|
||||||
# from rich import print as rprint
|
|
||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
from util.datasets import FiberRegenerationDataset
|
from util.datasets import FiberRegenerationDataset
|
||||||
|
|
||||||
# from util.optuna_helpers import (
|
|
||||||
# suggest_categorical_optional, # noqa: F401
|
|
||||||
# suggest_float_optional, # noqa: F401
|
|
||||||
# suggest_int_optional, # noqa: F401
|
|
||||||
# )
|
|
||||||
from util.optuna_helpers import install_optional_suggests
|
from util.optuna_helpers import install_optional_suggests
|
||||||
import util
|
import util
|
||||||
|
|
||||||
@@ -65,7 +56,6 @@ class HyperTraining:
|
|||||||
model_settings,
|
model_settings,
|
||||||
optimizer_settings,
|
optimizer_settings,
|
||||||
optuna_settings,
|
optuna_settings,
|
||||||
# console=None,
|
|
||||||
):
|
):
|
||||||
self.global_settings: GlobalSettings = global_settings
|
self.global_settings: GlobalSettings = global_settings
|
||||||
self.data_settings: DataSettings = data_settings
|
self.data_settings: DataSettings = data_settings
|
||||||
@@ -75,11 +65,8 @@ class HyperTraining:
|
|||||||
self.optuna_settings: OptunaSettings = optuna_settings
|
self.optuna_settings: OptunaSettings = optuna_settings
|
||||||
self.processes = None
|
self.processes = None
|
||||||
|
|
||||||
# self.console = console or Console()
|
|
||||||
|
|
||||||
# set some extra settings to make the code more readable
|
|
||||||
self._extra_optuna_settings()
|
self._extra_optuna_settings()
|
||||||
self.stop_study = True
|
self.stop_study = False
|
||||||
|
|
||||||
def setup_tb_writer(self, study_name=None, append=None):
|
def setup_tb_writer(self, study_name=None, append=None):
|
||||||
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
|
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
|
||||||
@@ -229,7 +216,7 @@ class HyperTraining:
|
|||||||
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
|
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
|
||||||
|
|
||||||
def define_model(self, trial: optuna.Trial, writer=None):
|
def define_model(self, trial: optuna.Trial, writer=None):
|
||||||
n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
|
n_hidden_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
|
||||||
|
|
||||||
input_dim = trial.suggest_int_optional(
|
input_dim = trial.suggest_int_optional(
|
||||||
"model_input_dim",
|
"model_input_dim",
|
||||||
@@ -245,32 +232,41 @@ class HyperTraining:
|
|||||||
dtype = getattr(torch, dtype)
|
dtype = getattr(torch, dtype)
|
||||||
|
|
||||||
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
|
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
|
||||||
# T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
|
afunc = getattr(util.complexNN, afunc)
|
||||||
|
layer_func = trial.suggest_categorical_optional("model_layer_function", self.model_settings.model_layer_function)
|
||||||
|
layer_func = getattr(util.complexNN, layer_func)
|
||||||
|
layer_parametrizations = self.model_settings.model_layer_parametrizations
|
||||||
|
|
||||||
layers = []
|
scale_layers = trial.suggest_categorical_optional("model_enable_scale_layers", self.model_settings.scale)
|
||||||
last_dim = input_dim
|
|
||||||
n_nodes = last_dim
|
|
||||||
for i in range(n_layers):
|
hidden_dims = []
|
||||||
|
for i in range(n_hidden_layers):
|
||||||
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
|
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
|
||||||
hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
|
hidden_dims.append(trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override))
|
||||||
else:
|
else:
|
||||||
hidden_dim = trial.suggest_int_optional(
|
hidden_dims.append(trial.suggest_int_optional(
|
||||||
f"model_hidden_dim_{i}",
|
f"model_hidden_dim_{i}",
|
||||||
self.model_settings.n_hidden_nodes,
|
self.model_settings.n_hidden_nodes,
|
||||||
)
|
))
|
||||||
layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
|
|
||||||
last_dim = hidden_dim
|
model_kwargs = {
|
||||||
layers.append(getattr(util.complexNN, afunc)())
|
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||||
n_nodes += last_dim
|
"layer_function": layer_func,
|
||||||
|
"layer_parametrizations": layer_parametrizations,
|
||||||
layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
|
"activation_function": afunc,
|
||||||
|
"dtype": dtype,
|
||||||
model = nn.Sequential(*layers)
|
"droupout_prob": self.model_settings.dropout_prob,
|
||||||
|
"scale": scale_layers,
|
||||||
|
}
|
||||||
|
|
||||||
|
model = util.complexNN.regenerator(**model_kwargs)
|
||||||
|
n_nodes = sum(hidden_dims)
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
|
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
|
||||||
|
|
||||||
n_params = sum(p.numel() for p in model.parameters())
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
trial.set_user_attr("model_n_params", n_params)
|
trial.set_user_attr("model_n_params", n_params)
|
||||||
trial.set_user_attr("model_n_nodes", n_nodes)
|
trial.set_user_attr("model_n_nodes", n_nodes)
|
||||||
|
|
||||||
@@ -384,7 +380,8 @@ class HyperTraining:
|
|||||||
running_loss2 = 0.0
|
running_loss2 = 0.0
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
model.train()
|
model.train()
|
||||||
for batch_idx, (x, y) in enumerate(train_loader):
|
loader_len = len(train_loader)
|
||||||
|
for batch_idx, (x, y, _) in enumerate(train_loader):
|
||||||
if batch_idx >= self.optuna_settings._n_train_batches:
|
if batch_idx >= self.optuna_settings._n_train_batches:
|
||||||
break
|
break
|
||||||
model.zero_grad(set_to_none=True)
|
model.zero_grad(set_to_none=True)
|
||||||
@@ -408,14 +405,14 @@ class HyperTraining:
|
|||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
"training loss",
|
"training loss",
|
||||||
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
|
epoch * min(loader_len, self.optuna_settings._n_train_batches) + batch_idx,
|
||||||
)
|
)
|
||||||
running_loss2 = 0.0
|
running_loss2 = 0.0
|
||||||
|
|
||||||
# if enable_progress:
|
# if enable_progress:
|
||||||
# progress.stop()
|
# progress.stop()
|
||||||
|
|
||||||
return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
|
return running_loss / min(loader_len, self.optuna_settings._n_train_batches)
|
||||||
|
|
||||||
def eval_model(
|
def eval_model(
|
||||||
self,
|
self,
|
||||||
@@ -446,9 +443,8 @@ class HyperTraining:
|
|||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
running_error = 0
|
running_error = 0
|
||||||
running_error_2 = 0
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, (x, y) in enumerate(valid_loader):
|
for batch_idx, (x, y, _) in enumerate(valid_loader):
|
||||||
if batch_idx >= self.optuna_settings._n_valid_batches:
|
if batch_idx >= self.optuna_settings._n_valid_batches:
|
||||||
break
|
break
|
||||||
x, y = (
|
x, y = (
|
||||||
@@ -459,19 +455,6 @@ class HyperTraining:
|
|||||||
error = util.complexNN.complex_mse_loss(y_pred, y)
|
error = util.complexNN.complex_mse_loss(y_pred, y)
|
||||||
error_value = error.item()
|
error_value = error.item()
|
||||||
running_error += error_value
|
running_error += error_value
|
||||||
running_error_2 += error_value
|
|
||||||
|
|
||||||
# if enable_progress:
|
|
||||||
# progress.update(task, advance=1, description=f"{error_value:.3e}")
|
|
||||||
|
|
||||||
if writer is not None:
|
|
||||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
|
||||||
writer.add_scalar(
|
|
||||||
"eval loss",
|
|
||||||
running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
|
||||||
epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
|
|
||||||
)
|
|
||||||
running_error_2 = 0.0
|
|
||||||
|
|
||||||
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
|
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
|
||||||
|
|
||||||
@@ -488,38 +471,73 @@ class HyperTraining:
|
|||||||
),
|
),
|
||||||
epoch + 1,
|
epoch + 1,
|
||||||
)
|
)
|
||||||
|
writer.add_figure(
|
||||||
|
"eye diagram",
|
||||||
|
self.plot_model_response(
|
||||||
|
trial,
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=False,
|
||||||
|
mode="eye",
|
||||||
|
),
|
||||||
|
epoch + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
writer.add_figure(
|
||||||
|
"powers",
|
||||||
|
self.plot_model_response(
|
||||||
|
trial,
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
mode="powers",
|
||||||
|
show=False,
|
||||||
|
),
|
||||||
|
epoch + 1,
|
||||||
|
)
|
||||||
|
|
||||||
# if enable_progress:
|
# if enable_progress:
|
||||||
# progress.stop()
|
# progress.stop()
|
||||||
|
|
||||||
return running_error
|
return running_error
|
||||||
|
|
||||||
def run_model(self, model, loader):
|
def run_model(self, model, loader, trace_powers=False):
|
||||||
model.eval()
|
model.eval()
|
||||||
xs = []
|
fiber_out = []
|
||||||
ys = []
|
fiber_in = []
|
||||||
y_preds = []
|
regen = []
|
||||||
|
timestamps = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model = model.to(self.pytorch_settings.device)
|
model = model.to(self.pytorch_settings.device)
|
||||||
for x, y in loader:
|
for x, y, timestamp in loader:
|
||||||
x, y = (
|
x, y = (
|
||||||
x.to(self.pytorch_settings.device),
|
x.to(self.pytorch_settings.device),
|
||||||
y.to(self.pytorch_settings.device),
|
y.to(self.pytorch_settings.device),
|
||||||
)
|
)
|
||||||
y_pred = model(x).cpu()
|
if trace_powers:
|
||||||
|
y_pred, powers = model(x, trace_powers).cpu()
|
||||||
|
else:
|
||||||
|
y_pred = model(x, trace_powers).cpu()
|
||||||
# x = x.cpu()
|
# x = x.cpu()
|
||||||
# y = y.cpu()
|
# y = y.cpu()
|
||||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||||
y = y.view(y.shape[0], -1, 2)
|
y = y.view(y.shape[0], -1, 2)
|
||||||
x = x.view(x.shape[0], -1, 2)
|
x = x.view(x.shape[0], -1, 2)
|
||||||
xs.append(x[:, 0, :].squeeze())
|
# timestamp = timestamp.view(-1, 1)
|
||||||
ys.append(y.squeeze())
|
fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
|
||||||
y_preds.append(y_pred.squeeze())
|
fiber_in.append(y.squeeze())
|
||||||
|
regen.append(y_pred.squeeze())
|
||||||
|
timestamps.append(timestamp.squeeze())
|
||||||
|
|
||||||
xs = torch.vstack(xs).cpu()
|
fiber_out = torch.vstack(fiber_out).cpu()
|
||||||
ys = torch.vstack(ys).cpu()
|
fiber_in = torch.vstack(fiber_in).cpu()
|
||||||
y_preds = torch.vstack(y_preds).cpu()
|
regen = torch.vstack(regen).cpu()
|
||||||
return ys, xs, y_preds
|
timestamps = torch.concat(timestamps).cpu()
|
||||||
|
if trace_powers:
|
||||||
|
return fiber_in, fiber_out, regen, timestamps, powers
|
||||||
|
return fiber_in, fiber_out, regen, timestamps
|
||||||
|
|
||||||
def objective(self, trial: optuna.Trial, plot_before=False):
|
def objective(self, trial: optuna.Trial, plot_before=False):
|
||||||
if self.stop_study:
|
if self.stop_study:
|
||||||
@@ -544,7 +562,32 @@ class HyperTraining:
|
|||||||
model=model,
|
model=model,
|
||||||
title_append=title_append,
|
title_append=title_append,
|
||||||
subtitle=subtitle,
|
subtitle=subtitle,
|
||||||
show=plot_before,
|
show=False,
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
writer.add_figure(
|
||||||
|
"eye diagram",
|
||||||
|
self.plot_model_response(
|
||||||
|
trial,
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
mode="eye",
|
||||||
|
show=False,
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
writer.add_figure(
|
||||||
|
"powers",
|
||||||
|
self.plot_model_response(
|
||||||
|
trial,
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
mode="powers",
|
||||||
|
show=False,
|
||||||
),
|
),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
@@ -609,7 +652,9 @@ class HyperTraining:
|
|||||||
|
|
||||||
return error
|
return error
|
||||||
|
|
||||||
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
def _plot_model_response_eye(
|
||||||
|
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
|
||||||
|
):
|
||||||
if sps is None:
|
if sps is None:
|
||||||
raise ValueError("sps must be provided")
|
raise ValueError("sps must be provided")
|
||||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
@@ -624,27 +669,84 @@ class HyperTraining:
|
|||||||
if not any(labels):
|
if not any(labels):
|
||||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
|
x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
|
||||||
|
y_bins = np.zeros((2 * len(signals), 1000))
|
||||||
|
eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
|
||||||
|
# signals = [signal.cpu().numpy() for signal in signals]
|
||||||
|
for i in range(len(signals) * 2):
|
||||||
|
eye_signal = signals[i // 2][:, i % 2] # x, y, x, y, ...
|
||||||
|
eye_signal = np.real(np.square(np.abs(eye_signal)))
|
||||||
|
data_min = np.min(eye_signal)
|
||||||
|
data_max = np.max(eye_signal)
|
||||||
|
y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
|
||||||
|
for j in range(len(timestamps)):
|
||||||
|
t = timestamps[j] / sps
|
||||||
|
val = eye_signal[j]
|
||||||
|
x = np.digitize(t % 2, x_bins) - 1
|
||||||
|
y = np.digitize(val, y_bins[i]) - 1
|
||||||
|
eye_data[i][y][x] += 1
|
||||||
|
|
||||||
|
cmap = LinearSegmentedColormap.from_list(
|
||||||
|
"eyemap",
|
||||||
|
[
|
||||||
|
(0, "white"),
|
||||||
|
(0.001, "dodgerblue"),
|
||||||
|
(0.1, "blue"),
|
||||||
|
(0.2, "cyan"),
|
||||||
|
(0.5, "lime"),
|
||||||
|
(0.8, "gold"),
|
||||||
|
(1, "red"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# ordering = np.argsort(timestamps)
|
||||||
|
# signals = [signal[ordering] for signal in signals]
|
||||||
|
# timestamps = timestamps[ordering]
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
||||||
|
fig.set_figwidth(18)
|
||||||
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
|
# xaxis = timestamps / sps
|
||||||
for j, (label, signal) in enumerate(zip(labels, signals)):
|
# xaxis = np.arange(2 * sps) / sps
|
||||||
|
for j, label in enumerate(labels):
|
||||||
|
x = eye_data[2 * j]
|
||||||
|
y = eye_data[2 * j + 1]
|
||||||
|
# x, y = signal.T
|
||||||
# signal = signal.cpu().numpy()
|
# signal = signal.cpu().numpy()
|
||||||
for i in range(len(signal) // sps - 1):
|
# for i in range(len(signal) // sps - 1):
|
||||||
x, y = signal[i * sps : (i + 2) * sps].T
|
# x, y = signal[i * sps : (i + 2) * sps].T
|
||||||
axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
|
# axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
|
||||||
axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
|
# axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
|
||||||
axs[0, j].set_title(label + " x")
|
axs[0 + 2 * j].imshow(
|
||||||
axs[1, j].set_title(label + " y")
|
x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
|
||||||
axs[0, j].set_xlabel("Symbol")
|
)
|
||||||
axs[1, j].set_xlabel("Symbol")
|
axs[1 + 2 * j].imshow(
|
||||||
axs[0, j].set_ylabel("normalized power")
|
y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
|
||||||
axs[1, j].set_ylabel("normalized power")
|
)
|
||||||
|
axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
|
||||||
|
axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
|
||||||
|
ymin = np.min(y_bins[:, 0])
|
||||||
|
ymax = np.max(y_bins[:, -1])
|
||||||
|
ydiff = ymax - ymin
|
||||||
|
axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
|
||||||
|
axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
|
||||||
|
axs[0 + 2 * j].set_title(label + " x")
|
||||||
|
axs[1 + 2 * j].set_title(label + " y")
|
||||||
|
axs[0 + 2 * j].set_xlabel("Symbol")
|
||||||
|
axs[1 + 2 * j].set_xlabel("Symbol")
|
||||||
|
axs[0 + 2 * j].set_box_aspect(1)
|
||||||
|
axs[1 + 2 * j].set_box_aspect(1)
|
||||||
|
axs[0].set_ylabel("normalized power")
|
||||||
|
fig.tight_layout()
|
||||||
|
# axs[1+2*len(labels)-1].set_ylabel("normalized power")
|
||||||
|
|
||||||
if show:
|
if show:
|
||||||
plt.show()
|
plt.show()
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
def _plot_model_response_head(
|
||||||
|
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
|
||||||
|
):
|
||||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
labels = [labels]
|
labels = [labels]
|
||||||
else:
|
else:
|
||||||
@@ -657,19 +759,31 @@ class HyperTraining:
|
|||||||
if not any(labels):
|
if not any(labels):
|
||||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
|
ordering = np.argsort(timestamps)
|
||||||
|
signals = [signal[ordering] for signal in signals]
|
||||||
|
timestamps = timestamps[ordering]
|
||||||
|
|
||||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
||||||
fig.set_size_inches(18, 6)
|
fig.set_figwidth(18)
|
||||||
|
fig.set_figheight(4)
|
||||||
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
for i, ax in enumerate(axs):
|
for i, ax in enumerate(axs):
|
||||||
|
ax: plt.Axes
|
||||||
for signal, label in zip(signals, labels):
|
for signal, label in zip(signals, labels):
|
||||||
if sps is not None:
|
if sps is not None:
|
||||||
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
|
xaxis = timestamps / sps
|
||||||
else:
|
else:
|
||||||
xaxis = np.arange(len(signal))
|
xaxis = timestamps
|
||||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||||
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
||||||
ax.set_ylabel("normalized power")
|
ax.set_ylabel("normalized power")
|
||||||
|
ax.minorticks_on()
|
||||||
|
ax.tick_params(axis="y", which="minor", left=False, right=False)
|
||||||
|
ax.grid(which="major", axis="x")
|
||||||
|
ax.grid(which="minor", axis="x", linestyle=":")
|
||||||
|
ax.grid(which="major", axis="y")
|
||||||
ax.legend(loc="upper right")
|
ax.legend(loc="upper right")
|
||||||
|
fig.tight_layout()
|
||||||
if show:
|
if show:
|
||||||
plt.show()
|
plt.show()
|
||||||
return fig
|
return fig
|
||||||
@@ -680,22 +794,52 @@ class HyperTraining:
|
|||||||
model=None,
|
model=None,
|
||||||
title_append="",
|
title_append="",
|
||||||
subtitle="",
|
subtitle="",
|
||||||
mode: Literal["eye", "head"] = "head",
|
mode: Literal["eye", "head", "powers"] = "head",
|
||||||
show=True,
|
show=False,
|
||||||
):
|
):
|
||||||
|
if mode == "powers":
|
||||||
|
input_data = torch.ones(
|
||||||
|
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
|
||||||
|
).to(self.pytorch_settings.device)
|
||||||
|
model = model.to(self.pytorch_settings.device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
_, powers = model(input_data, trace_powers=True)
|
||||||
|
|
||||||
|
powers = [power.item() for power in powers]
|
||||||
|
layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
|
||||||
|
|
||||||
|
# remove dropout layers
|
||||||
|
mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
|
||||||
|
layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
|
||||||
|
powers = [power for power, m in zip(powers, mask) if m]
|
||||||
|
|
||||||
|
fig = self._plot_model_response_powers(
|
||||||
|
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
|
||||||
|
)
|
||||||
|
return fig
|
||||||
|
|
||||||
data_settings_backup = copy.deepcopy(self.data_settings)
|
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||||
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||||
self.data_settings.drop_first = 100*128
|
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
|
||||||
self.data_settings.shuffle = False
|
self.data_settings.shuffle = False
|
||||||
self.data_settings.train_split = 1.0
|
self.data_settings.train_split = 1.0
|
||||||
self.pytorch_settings.batchsize = (
|
self.pytorch_settings.batchsize = (
|
||||||
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
||||||
)
|
)
|
||||||
plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
|
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
|
||||||
|
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
|
||||||
|
plot_loader, _ = self.get_sliced_data(
|
||||||
|
trial,
|
||||||
|
override={
|
||||||
|
"num_symbols": self.pytorch_settings.batchsize,
|
||||||
|
"config_path": config_path,
|
||||||
|
}
|
||||||
|
)
|
||||||
self.data_settings = data_settings_backup
|
self.data_settings = data_settings_backup
|
||||||
self.pytorch_settings = pytorch_settings_backup
|
self.pytorch_settings = pytorch_settings_backup
|
||||||
|
|
||||||
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
|
fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
|
||||||
fiber_in = fiber_in.view(-1, 2)
|
fiber_in = fiber_in.view(-1, 2)
|
||||||
fiber_out = fiber_out.view(-1, 2)
|
fiber_out = fiber_out.view(-1, 2)
|
||||||
regen = regen.view(-1, 2)
|
regen = regen.view(-1, 2)
|
||||||
@@ -703,6 +847,7 @@ class HyperTraining:
|
|||||||
fiber_in = fiber_in.numpy()
|
fiber_in = fiber_in.numpy()
|
||||||
fiber_out = fiber_out.numpy()
|
fiber_out = fiber_out.numpy()
|
||||||
regen = regen.numpy()
|
regen = regen.numpy()
|
||||||
|
timestamps = timestamps.numpy()
|
||||||
|
|
||||||
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
|
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
|
||||||
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
||||||
@@ -713,9 +858,10 @@ class HyperTraining:
|
|||||||
fiber_in,
|
fiber_in,
|
||||||
fiber_out,
|
fiber_out,
|
||||||
regen,
|
regen,
|
||||||
|
timestamps=timestamps,
|
||||||
labels=("fiber in", "fiber out", "regen"),
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
sps=plot_loader.dataset.samples_per_symbol,
|
sps=plot_loader.dataset.samples_per_symbol,
|
||||||
title_append=title_append,
|
title_append=title_append + f" ({fiber_length} km)",
|
||||||
subtitle=subtitle,
|
subtitle=subtitle,
|
||||||
show=show,
|
show=show,
|
||||||
)
|
)
|
||||||
@@ -725,9 +871,10 @@ class HyperTraining:
|
|||||||
fiber_in,
|
fiber_in,
|
||||||
fiber_out,
|
fiber_out,
|
||||||
regen,
|
regen,
|
||||||
|
timestamps=timestamps,
|
||||||
labels=("fiber in", "fiber out", "regen"),
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
sps=plot_loader.dataset.samples_per_symbol,
|
sps=plot_loader.dataset.samples_per_symbol,
|
||||||
title_append=title_append,
|
title_append=title_append + f" ({fiber_length} km)",
|
||||||
subtitle=subtitle,
|
subtitle=subtitle,
|
||||||
show=show,
|
show=show,
|
||||||
)
|
)
|
||||||
@@ -739,7 +886,7 @@ class HyperTraining:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_title(trial: optuna.trial.Trial):
|
def build_title(trial: optuna.trial.Trial):
|
||||||
title_append = f"for trial {trial.number}"
|
title_append = f"at epoch {trial.user_attrs.get("epoch", -1)} for trial {trial.number}"
|
||||||
model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
|
model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
|
||||||
input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
|
input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
|
||||||
model_dims = [
|
model_dims = [
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
# import torch.nn as nn
|
||||||
|
|
||||||
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
@@ -47,88 +47,6 @@ from .settings import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class regenerator(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*dims,
|
|
||||||
layer_function=util.complexNN.ONN,
|
|
||||||
layer_kwargs: dict | None = None,
|
|
||||||
layer_parametrizations: list[dict] = None,
|
|
||||||
# [
|
|
||||||
# {
|
|
||||||
# "tensor_name": "weight",
|
|
||||||
# "parametrization": util.complexNN.Unitary,
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "tensor_name": "scale",
|
|
||||||
# "parametrization": util.complexNN.Clamp,
|
|
||||||
# },
|
|
||||||
# ],
|
|
||||||
activation_function=util.complexNN.Pow,
|
|
||||||
dtype=torch.float64,
|
|
||||||
dropout_prob=0.01,
|
|
||||||
scale=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super(regenerator, self).__init__()
|
|
||||||
if len(dims) == 0:
|
|
||||||
try:
|
|
||||||
dims = kwargs["dims"]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError("dims must be provided")
|
|
||||||
self._n_hidden_layers = len(dims) - 2
|
|
||||||
self._layers = nn.Sequential()
|
|
||||||
if layer_kwargs is None:
|
|
||||||
layer_kwargs = {}
|
|
||||||
# self.powers = []
|
|
||||||
|
|
||||||
for i in range(self._n_hidden_layers + 1):
|
|
||||||
if scale:
|
|
||||||
self._layers.append(util.complexNN.Scale(dims[i]))
|
|
||||||
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
|
|
||||||
if i < self._n_hidden_layers:
|
|
||||||
if dropout_prob is not None:
|
|
||||||
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
|
|
||||||
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
|
|
||||||
|
|
||||||
self._layers.append(util.complexNN.Scale(dims[-1]))
|
|
||||||
|
|
||||||
# add parametrizations
|
|
||||||
if layer_parametrizations is not None:
|
|
||||||
for layer in self._layers:
|
|
||||||
for layer_parametrization in layer_parametrizations:
|
|
||||||
tensor_name = layer_parametrization.get("tensor_name", None)
|
|
||||||
parametrization = layer_parametrization.get("parametrization", None)
|
|
||||||
param_kwargs = layer_parametrization.get("kwargs", {})
|
|
||||||
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
|
|
||||||
parametrization(layer, tensor_name, **param_kwargs)
|
|
||||||
|
|
||||||
# def __call__(self, input_x, **kwargs):
|
|
||||||
# return self.forward(input_x, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, input_x, trace_powers=False):
|
|
||||||
x = input_x
|
|
||||||
|
|
||||||
if trace_powers:
|
|
||||||
powers = [x.abs().square().sum()]
|
|
||||||
|
|
||||||
# check if tracing
|
|
||||||
if torch.jit.is_tracing():
|
|
||||||
for layer in self._layers:
|
|
||||||
x = layer(x)
|
|
||||||
if trace_powers:
|
|
||||||
powers.append(x.abs().square().sum())
|
|
||||||
else:
|
|
||||||
# with torch.nn.utils.parametrize.cached():
|
|
||||||
for layer in self._layers:
|
|
||||||
x = layer(x)
|
|
||||||
if trace_powers:
|
|
||||||
powers.append(x.abs().square().sum())
|
|
||||||
if trace_powers:
|
|
||||||
return x, powers
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def traverse_dict_update(target, source):
|
def traverse_dict_update(target, source):
|
||||||
for k, v in source.items():
|
for k, v in source.items():
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
@@ -164,7 +82,7 @@ class Trainer:
|
|||||||
ModelSettings,
|
ModelSettings,
|
||||||
OptimizerSettings,
|
OptimizerSettings,
|
||||||
PytorchSettings,
|
PytorchSettings,
|
||||||
regenerator,
|
util.complexNN.regenerator,
|
||||||
torch.nn.utils.parametrizations.orthogonal,
|
torch.nn.utils.parametrizations.orthogonal,
|
||||||
])
|
])
|
||||||
if self.resume:
|
if self.resume:
|
||||||
@@ -264,7 +182,7 @@ class Trainer:
|
|||||||
dtype = self.model_kwargs["dtype"]
|
dtype = self.model_kwargs["dtype"]
|
||||||
|
|
||||||
# dims = self.model_kwargs.pop("dims")
|
# dims = self.model_kwargs.pop("dims")
|
||||||
self.model = regenerator(**self.model_kwargs)
|
self.model = util.complexNN.regenerator(**self.model_kwargs)
|
||||||
|
|
||||||
if self.writer is not None:
|
if self.writer is not None:
|
||||||
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
||||||
@@ -364,7 +282,7 @@ class Trainer:
|
|||||||
task = progress.add_task("-.---e--", total=len(train_loader))
|
task = progress.add_task("-.---e--", total=len(train_loader))
|
||||||
progress.start()
|
progress.start()
|
||||||
|
|
||||||
# running_loss2 = 0.0
|
running_loss2 = 0.0
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
self.model.train()
|
self.model.train()
|
||||||
loader_len = len(train_loader)
|
loader_len = len(train_loader)
|
||||||
@@ -379,23 +297,24 @@ class Trainer:
|
|||||||
loss_value = loss.item()
|
loss_value = loss.item()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
# running_loss2 += loss_value
|
running_loss2 += loss_value
|
||||||
running_loss += loss_value
|
running_loss += loss_value
|
||||||
|
|
||||||
if enable_progress:
|
if enable_progress:
|
||||||
progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}")
|
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||||
|
|
||||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||||
self.writer.add_scalar(
|
self.writer.add_scalar(
|
||||||
"training loss",
|
"training loss",
|
||||||
running_loss / (batch_idx + 1),
|
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
epoch * loader_len + batch_idx,
|
epoch + batch_idx/loader_len,
|
||||||
)
|
)
|
||||||
|
running_loss2 = 0.0
|
||||||
|
|
||||||
if enable_progress:
|
if enable_progress:
|
||||||
progress.stop()
|
progress.stop()
|
||||||
|
|
||||||
return running_loss / (batch_idx + 1)
|
return running_loss / len(train_loader)
|
||||||
|
|
||||||
def eval_model(self, valid_loader, epoch, enable_progress=True):
|
def eval_model(self, valid_loader, epoch, enable_progress=True):
|
||||||
if enable_progress:
|
if enable_progress:
|
||||||
@@ -418,7 +337,7 @@ class Trainer:
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
running_error = 0
|
running_error = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, (x, y, _) in enumerate(valid_loader):
|
for _, (x, y, _) in enumerate(valid_loader):
|
||||||
x, y = (
|
x, y = (
|
||||||
x.to(self.pytorch_settings.device),
|
x.to(self.pytorch_settings.device),
|
||||||
y.to(self.pytorch_settings.device),
|
y.to(self.pytorch_settings.device),
|
||||||
@@ -429,9 +348,9 @@ class Trainer:
|
|||||||
running_error += error_value
|
running_error += error_value
|
||||||
|
|
||||||
if enable_progress:
|
if enable_progress:
|
||||||
progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}")
|
progress.update(task, advance=1, description=f"{error_value:.3e}")
|
||||||
|
|
||||||
running_error /= (batch_idx+1)
|
running_error = running_error/len(valid_loader)
|
||||||
|
|
||||||
self.writer.add_scalar(
|
self.writer.add_scalar(
|
||||||
"eval loss",
|
"eval loss",
|
||||||
@@ -858,7 +777,7 @@ class Trainer:
|
|||||||
self.pytorch_settings.batchsize = (
|
self.pytorch_settings.batchsize = (
|
||||||
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
||||||
)
|
)
|
||||||
config_path = random.choice(self.data_settings.config_path)
|
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
|
||||||
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
|
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
|
||||||
plot_loader, _ = self.get_sliced_data(
|
plot_loader, _ = self.get_sliced_data(
|
||||||
override={
|
override={
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from pathlib import Path
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -22,8 +23,8 @@ global_settings = GlobalSettings(
|
|||||||
)
|
)
|
||||||
|
|
||||||
data_settings = DataSettings(
|
data_settings = DataSettings(
|
||||||
# config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
|
config_path="data/20241204-13*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||||
config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
|
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||||
dtype="complex64",
|
dtype="complex64",
|
||||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||||
symbols=13, # study: single_core_regen_20241123_011232
|
symbols=13, # study: single_core_regen_20241123_011232
|
||||||
@@ -52,8 +53,8 @@ model_settings = ModelSettings(
|
|||||||
output_dim=2,
|
output_dim=2,
|
||||||
n_hidden_layers=4,
|
n_hidden_layers=4,
|
||||||
overrides={
|
overrides={
|
||||||
"n_hidden_nodes_0": 4,
|
"n_hidden_nodes_0": 8,
|
||||||
"n_hidden_nodes_1": 4,
|
"n_hidden_nodes_1": 8,
|
||||||
"n_hidden_nodes_2": 4,
|
"n_hidden_nodes_2": 4,
|
||||||
"n_hidden_nodes_3": 4,
|
"n_hidden_nodes_3": 4,
|
||||||
},
|
},
|
||||||
@@ -61,7 +62,7 @@ model_settings = ModelSettings(
|
|||||||
dropout_prob=0.01,
|
dropout_prob=0.01,
|
||||||
model_layer_function="ONNRect",
|
model_layer_function="ONNRect",
|
||||||
model_layer_kwargs={"square": True},
|
model_layer_kwargs={"square": True},
|
||||||
scale=True,
|
scale=False,
|
||||||
model_layer_parametrizations=[
|
model_layer_parametrizations=[
|
||||||
{
|
{
|
||||||
"tensor_name": "weight",
|
"tensor_name": "weight",
|
||||||
@@ -113,7 +114,7 @@ model_settings = ModelSettings(
|
|||||||
optimizer_settings = OptimizerSettings(
|
optimizer_settings = OptimizerSettings(
|
||||||
optimizer="AdamW",
|
optimizer="AdamW",
|
||||||
optimizer_kwargs={
|
optimizer_kwargs={
|
||||||
"lr": 0.05,
|
"lr": 0.01,
|
||||||
"amsgrad": True,
|
"amsgrad": True,
|
||||||
# "weight_decay": 1e-7,
|
# "weight_decay": 1e-7,
|
||||||
},
|
},
|
||||||
@@ -142,8 +143,9 @@ def save_dict_to_file(dictionary, filename):
|
|||||||
json.dump(dictionary, f, indent=4)
|
json.dump(dictionary, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
def sweep_lengths(*lengths, model=None):
|
def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
|
||||||
assert model is not None, "Model must be provided."
|
assert model is not None, "Model must be provided."
|
||||||
|
assert data_glob is not None, "Data glob must be provided."
|
||||||
model = model
|
model = model
|
||||||
|
|
||||||
fiber_ins = {}
|
fiber_ins = {}
|
||||||
@@ -151,19 +153,31 @@ def sweep_lengths(*lengths, model=None):
|
|||||||
regens = {}
|
regens = {}
|
||||||
timestampss = {}
|
timestampss = {}
|
||||||
|
|
||||||
for length in lengths:
|
trainer = Trainer(
|
||||||
trainer = Trainer(
|
|
||||||
checkpoint_path=model,
|
checkpoint_path=model,
|
||||||
settings_override={
|
|
||||||
"data_settings": {
|
|
||||||
"config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
|
|
||||||
"train_split": 1,
|
|
||||||
"shuffle": True,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
trainer.define_model()
|
trainer.define_model()
|
||||||
loader, _ = trainer.get_sliced_data()
|
|
||||||
|
for length in lengths:
|
||||||
|
data_glob_length = data_glob.replace("{length}", str(length))
|
||||||
|
files = list(Path.cwd().glob(data_glob_length))
|
||||||
|
if len(files) == 0:
|
||||||
|
continue
|
||||||
|
if strategy == "newest":
|
||||||
|
sorted_kwargs = {
|
||||||
|
'key': lambda x: x.stat().st_mtime,
|
||||||
|
'reverse': True,
|
||||||
|
}
|
||||||
|
elif strategy == "oldest":
|
||||||
|
sorted_kwargs = {
|
||||||
|
'key': lambda x: x.stat().st_mtime,
|
||||||
|
'reverse': False,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown strategy {strategy}.")
|
||||||
|
file = sorted(files, **sorted_kwargs)[0]
|
||||||
|
|
||||||
|
loader, _ = trainer.get_sliced_data(override={"config_path": file})
|
||||||
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
||||||
|
|
||||||
fiber_ins[length] = fiber_in
|
fiber_ins[length] = fiber_in
|
||||||
@@ -171,17 +185,23 @@ def sweep_lengths(*lengths, model=None):
|
|||||||
regens[length] = regen
|
regens[length] = regen
|
||||||
timestampss[length] = timestamps
|
timestampss[length] = timestamps
|
||||||
|
|
||||||
data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
|
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
|
||||||
channel_names = ["" for _ in range(2 * len(lengths))]
|
channel_names = ["" for _ in range(2 * len(timestampss.keys())+2)]
|
||||||
|
|
||||||
for li, length in enumerate(lengths):
|
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
|
||||||
data[2 * li, 0, :] = timestampss[length] / 128
|
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
|
||||||
data[2 * li, 1, :] = regens[length][:, 0].abs().square()
|
|
||||||
data[2 * li + 1, 0, :] = timestampss[length] / 128
|
|
||||||
data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
|
|
||||||
|
|
||||||
channel_names[2 * li] = f"regen x {length}"
|
channel_names[1] = "fiber in x"
|
||||||
channel_names[2 * li + 1] = f"regen y {length}"
|
|
||||||
|
|
||||||
|
for li, length in enumerate(timestampss.keys()):
|
||||||
|
data[2+2 * li, 0, :] = timestampss[length] / 128
|
||||||
|
data[2+2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
||||||
|
data[2+2 * li + 1, 0, :] = timestampss[length] / 128
|
||||||
|
data[2+2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
|
||||||
|
|
||||||
|
channel_names[2+2 * li+1] = f"regen x {length}"
|
||||||
|
channel_names[2+2 * li] = f"fiber out x {length}"
|
||||||
|
|
||||||
# get current backend
|
# get current backend
|
||||||
backend = matplotlib.get_backend()
|
backend = matplotlib.get_backend()
|
||||||
@@ -189,28 +209,30 @@ def sweep_lengths(*lengths, model=None):
|
|||||||
matplotlib.use("TkCairo")
|
matplotlib.use("TkCairo")
|
||||||
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
||||||
|
|
||||||
print_attrs = ("channel", "success", "min_area")
|
print_attrs = ("channel_name", "success", "min_area")
|
||||||
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
|
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
|
||||||
for result in eye.eye_stats:
|
for result in eye.eye_stats:
|
||||||
print_dict = {attr: result[attr] for attr in print_attrs}
|
print_dict = {attr: result[attr] for attr in print_attrs}
|
||||||
rprint(print_dict)
|
rprint(print_dict)
|
||||||
rprint()
|
rprint()
|
||||||
|
|
||||||
eye.plot()
|
eye.plot(all_stats=False)
|
||||||
matplotlib.use(backend)
|
matplotlib.use(backend)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
|
lengths = range(90000, 100000+10000, 10000)
|
||||||
|
# lengths = [100000]
|
||||||
|
sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
|
||||||
|
|
||||||
trainer = Trainer(
|
# trainer = Trainer(
|
||||||
global_settings=global_settings,
|
# global_settings=global_settings,
|
||||||
data_settings=data_settings,
|
# data_settings=data_settings,
|
||||||
pytorch_settings=pytorch_settings,
|
# pytorch_settings=pytorch_settings,
|
||||||
model_settings=model_settings,
|
# model_settings=model_settings,
|
||||||
optimizer_settings=optimizer_settings,
|
# optimizer_settings=optimizer_settings,
|
||||||
# checkpoint_path=".models/best_20241202_143149.tar",
|
# # checkpoint_path=".models/best_20241202_143149.tar",
|
||||||
# 20241202_143149
|
# # 20241202_143149
|
||||||
)
|
# )
|
||||||
trainer.train()
|
# trainer.train()
|
||||||
@@ -39,7 +39,7 @@ import numpy as np
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
|
dataset = FiberRegenerationDataset("data/202412*-128-16384-50000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
|
||||||
|
|
||||||
loader = DataLoader(dataset, batch_size=10, shuffle=True)
|
loader = DataLoader(dataset, batch_size=10, shuffle=True)
|
||||||
|
|
||||||
|
|||||||
@@ -569,6 +569,78 @@ class ZReLU(nn.Module):
|
|||||||
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
|
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
|
||||||
else:
|
else:
|
||||||
return torch.relu(x)
|
return torch.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class regenerator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*dims,
|
||||||
|
layer_function=ONN,
|
||||||
|
layer_kwargs: dict | None = None,
|
||||||
|
layer_parametrizations: list[dict] = None,
|
||||||
|
activation_function=Pow,
|
||||||
|
dtype=torch.float64,
|
||||||
|
dropout_prob=0.01,
|
||||||
|
scale=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super(regenerator, self).__init__()
|
||||||
|
if len(dims) == 0:
|
||||||
|
try:
|
||||||
|
dims = kwargs["dims"]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError("dims must be provided")
|
||||||
|
self._n_hidden_layers = len(dims) - 2
|
||||||
|
self._layers = nn.Sequential()
|
||||||
|
if layer_kwargs is None:
|
||||||
|
layer_kwargs = {}
|
||||||
|
# self.powers = []
|
||||||
|
|
||||||
|
for i in range(self._n_hidden_layers + 1):
|
||||||
|
if scale:
|
||||||
|
self._layers.append(Scale(dims[i]))
|
||||||
|
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
|
||||||
|
if i < self._n_hidden_layers:
|
||||||
|
if dropout_prob is not None:
|
||||||
|
self._layers.append(DropoutComplex(p=dropout_prob))
|
||||||
|
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
|
||||||
|
|
||||||
|
self._layers.append(Scale(dims[-1]))
|
||||||
|
|
||||||
|
# add parametrizations
|
||||||
|
if layer_parametrizations is not None:
|
||||||
|
for layer in self._layers:
|
||||||
|
for layer_parametrization in layer_parametrizations:
|
||||||
|
tensor_name = layer_parametrization.get("tensor_name", None)
|
||||||
|
parametrization = layer_parametrization.get("parametrization", None)
|
||||||
|
param_kwargs = layer_parametrization.get("kwargs", {})
|
||||||
|
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
|
||||||
|
parametrization(layer, tensor_name, **param_kwargs)
|
||||||
|
|
||||||
|
# def __call__(self, input_x, **kwargs):
|
||||||
|
# return self.forward(input_x, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, input_x, trace_powers=False):
|
||||||
|
x = input_x
|
||||||
|
|
||||||
|
if trace_powers:
|
||||||
|
powers = [x.abs().square().sum()]
|
||||||
|
|
||||||
|
# check if tracing
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
for layer in self._layers:
|
||||||
|
x = layer(x)
|
||||||
|
if trace_powers:
|
||||||
|
powers.append(x.abs().square().sum())
|
||||||
|
else:
|
||||||
|
# with torch.nn.utils.parametrize.cached():
|
||||||
|
for layer in self._layers:
|
||||||
|
x = layer(x)
|
||||||
|
if trace_powers:
|
||||||
|
powers.append(x.abs().square().sum())
|
||||||
|
if trace_powers:
|
||||||
|
return x, powers
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from matplotlib.colors import LinearSegmentedColormap
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.cluster.vq import kmeans2
|
from scipy.cluster.vq import kmeans2
|
||||||
import warnings
|
import warnings
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from rich import pretty
|
from rich import pretty
|
||||||
@@ -67,7 +68,7 @@ def generate_wavelet(sps, oversample=3):
|
|||||||
|
|
||||||
|
|
||||||
class eye_diagram:
|
class eye_diagram:
|
||||||
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4):
|
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
|
||||||
# data has shape [channels, 2, samples]
|
# data has shape [channels, 2, samples]
|
||||||
# each sample has a timestamp and a value
|
# each sample has a timestamp and a value
|
||||||
if data.ndim == 2:
|
if data.ndim == 2:
|
||||||
@@ -79,28 +80,38 @@ class eye_diagram:
|
|||||||
self.eye_stats = [{"success": False} for _ in range(self.channels)]
|
self.eye_stats = [{"success": False} for _ in range(self.channels)]
|
||||||
self.horizontal_bins = horizontal_bins
|
self.horizontal_bins = horizontal_bins
|
||||||
self.vertical_bins = vertical_bins
|
self.vertical_bins = vertical_bins
|
||||||
|
self.multi_threaded = multithreaded
|
||||||
self.eye_built = False
|
self.eye_built = False
|
||||||
self.analyse(self.n_levels)
|
self.analyse()
|
||||||
|
|
||||||
def generate_eye_data(self):
|
def generate_eye_data(self):
|
||||||
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
||||||
self.y_bins = np.zeros((self.channels, self.vertical_bins))
|
self.y_bins = np.zeros((self.channels, self.vertical_bins))
|
||||||
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
|
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
|
||||||
for i in range(self.channels):
|
datas = [self.raw_data[i] for i in range(self.channels)]
|
||||||
data_min = np.min(self.raw_data[i, 1, :])
|
if self.multi_threaded:
|
||||||
data_max = np.max(self.raw_data[i, 1, :])
|
with multiprocessing.Pool() as pool:
|
||||||
self.y_bins[i] = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
|
results = pool.map(self.generate_eye_data_single, datas)
|
||||||
|
for i, result in enumerate(results):
|
||||||
t_vals = self.raw_data[i, 0, :] % 2
|
self.eye_data[i], self.y_bins[i] = result
|
||||||
val_vals = self.raw_data[i, 1, :]
|
else:
|
||||||
|
for i, data in enumerate(datas):
|
||||||
x_indices = np.digitize(t_vals, self.x_bins) - 1
|
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
|
||||||
y_indices = np.digitize(val_vals, self.y_bins[i]) - 1
|
|
||||||
|
|
||||||
np.add.at(self.eye_data[i], (y_indices, x_indices), 1)
|
|
||||||
self.eye_built = True
|
self.eye_built = True
|
||||||
|
|
||||||
|
def generate_eye_data_single(self, data):
|
||||||
|
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
|
||||||
|
data_min = np.min(data[1, :])
|
||||||
|
data_max = np.max(data[1, :])
|
||||||
|
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
|
||||||
|
t_vals = data[0, :] % 2
|
||||||
|
val_vals = data[1, :]
|
||||||
|
x_indices = np.digitize(t_vals, self.x_bins) - 1
|
||||||
|
y_indices = np.digitize(val_vals, y_bins) - 1
|
||||||
|
np.add.at(eye_data, (y_indices, x_indices), 1)
|
||||||
|
return eye_data, y_bins
|
||||||
|
|
||||||
def plot(self, title="Eye Diagram", stats=True, show=True):
|
def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
|
||||||
if not self.eye_built:
|
if not self.eye_built:
|
||||||
self.generate_eye_data()
|
self.generate_eye_data()
|
||||||
cmap = LinearSegmentedColormap.from_list(
|
cmap = LinearSegmentedColormap.from_list(
|
||||||
@@ -118,8 +129,10 @@ class eye_diagram:
|
|||||||
ax = np.atleast_1d(ax).transpose().flatten()
|
ax = np.atleast_1d(ax).transpose().flatten()
|
||||||
for i in range(self.channels):
|
for i in range(self.channels):
|
||||||
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
|
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
|
||||||
ax[i].set_xlabel("Symbol")
|
if (i+1) % rows == 0:
|
||||||
ax[i].set_ylabel("Amplitude")
|
ax[i].set_xlabel("Symbol")
|
||||||
|
if i < rows:
|
||||||
|
ax[i].set_ylabel("Amplitude")
|
||||||
ax[i].grid()
|
ax[i].grid()
|
||||||
ax[i].imshow(
|
ax[i].imshow(
|
||||||
self.eye_data[i],
|
self.eye_data[i],
|
||||||
@@ -134,67 +147,6 @@ class eye_diagram:
|
|||||||
yspan = ymax - ymin
|
yspan = ymax - ymin
|
||||||
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
|
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
|
||||||
if stats and self.eye_stats[i]["success"]:
|
if stats and self.eye_stats[i]["success"]:
|
||||||
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
|
|
||||||
ax[i].set_yticks(self.eye_stats[i]["levels"])
|
|
||||||
# add arrows for amplitudes
|
|
||||||
for j in range(len(self.eye_stats[i]["amplitudes"])):
|
|
||||||
ax[i].annotate(
|
|
||||||
"",
|
|
||||||
xy=(0.05, self.eye_stats[i]["levels"][j]),
|
|
||||||
xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
|
|
||||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
|
||||||
)
|
|
||||||
ax[i].annotate(
|
|
||||||
f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
|
|
||||||
xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
|
|
||||||
)
|
|
||||||
# add arrows for eye heights
|
|
||||||
for j in range(len(self.eye_stats[i]["heights"])):
|
|
||||||
try:
|
|
||||||
bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
|
|
||||||
top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
|
|
||||||
|
|
||||||
ax[i].annotate(
|
|
||||||
"",
|
|
||||||
xy=(self.eye_stats[i]["time_midpoint"], bot),
|
|
||||||
xytext=(self.eye_stats[i]["time_midpoint"], top),
|
|
||||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
|
||||||
)
|
|
||||||
ax[i].annotate(
|
|
||||||
f"{self.eye_stats[i]['heights'][j]:.2e}",
|
|
||||||
xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
|
|
||||||
)
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
# add arrows for eye widths
|
|
||||||
for j in range(len(self.eye_stats[i]["widths"])):
|
|
||||||
try:
|
|
||||||
left = np.max(self.eye_stats[i]["time_clusters"][j][0])
|
|
||||||
right = np.min(self.eye_stats[i]["time_clusters"][j][1])
|
|
||||||
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
|
||||||
|
|
||||||
ax[i].annotate(
|
|
||||||
"",
|
|
||||||
xy=(left, vertical),
|
|
||||||
xytext=(right, vertical),
|
|
||||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
|
||||||
)
|
|
||||||
ax[i].annotate(
|
|
||||||
f"{self.eye_stats[i]['widths'][j]:.2e}",
|
|
||||||
xy=((left + right) / 2 - 0.15, vertical + 0.01),
|
|
||||||
)
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# add area
|
|
||||||
for j in range(len(self.eye_stats[i]["areas"])):
|
|
||||||
horizontal = self.eye_stats[i]["time_midpoint"]
|
|
||||||
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
|
||||||
ax[i].annotate(
|
|
||||||
f"{self.eye_stats[i]['areas'][j]:.2e}",
|
|
||||||
xy=(horizontal + 0.035, vertical - 0.07),
|
|
||||||
)
|
|
||||||
|
|
||||||
# add min_area above the plot
|
# add min_area above the plot
|
||||||
ax[i].annotate(
|
ax[i].annotate(
|
||||||
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
|
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
|
||||||
@@ -202,62 +154,142 @@ class eye_diagram:
|
|||||||
# xycoords="axes fraction",
|
# xycoords="axes fraction",
|
||||||
ha="left",
|
ha="left",
|
||||||
va="center",
|
va="center",
|
||||||
|
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if all_stats:
|
||||||
|
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
|
||||||
|
ax[i].set_yticks(self.eye_stats[i]["levels"])
|
||||||
|
# add arrows for amplitudes
|
||||||
|
for j in range(len(self.eye_stats[i]["amplitudes"])):
|
||||||
|
ax[i].annotate(
|
||||||
|
"",
|
||||||
|
xy=(0.05, self.eye_stats[i]["levels"][j]),
|
||||||
|
xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
|
||||||
|
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||||
|
)
|
||||||
|
ax[i].annotate(
|
||||||
|
f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
|
||||||
|
xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
|
||||||
|
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||||
|
)
|
||||||
|
# add arrows for eye heights
|
||||||
|
for j in range(len(self.eye_stats[i]["heights"])):
|
||||||
|
try:
|
||||||
|
bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
|
||||||
|
top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
|
||||||
|
|
||||||
|
ax[i].annotate(
|
||||||
|
"",
|
||||||
|
xy=(self.eye_stats[i]["time_midpoint"], bot),
|
||||||
|
xytext=(self.eye_stats[i]["time_midpoint"], top),
|
||||||
|
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||||
|
)
|
||||||
|
ax[i].annotate(
|
||||||
|
f"{self.eye_stats[i]['heights'][j]:.2e}",
|
||||||
|
xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
|
||||||
|
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||||
|
)
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
# add arrows for eye widths
|
||||||
|
for j in range(len(self.eye_stats[i]["widths"])):
|
||||||
|
try:
|
||||||
|
left = np.max(self.eye_stats[i]["time_clusters"][j][0])
|
||||||
|
right = np.min(self.eye_stats[i]["time_clusters"][j][1])
|
||||||
|
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||||
|
|
||||||
|
ax[i].annotate(
|
||||||
|
"",
|
||||||
|
xy=(left, vertical),
|
||||||
|
xytext=(right, vertical),
|
||||||
|
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||||
|
)
|
||||||
|
ax[i].annotate(
|
||||||
|
f"{self.eye_stats[i]['widths'][j]:.2e}",
|
||||||
|
xy=((left + right) / 2 - 0.15, vertical + 0.01),
|
||||||
|
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||||
|
)
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# add area
|
||||||
|
for j in range(len(self.eye_stats[i]["areas"])):
|
||||||
|
horizontal = self.eye_stats[i]["time_midpoint"]
|
||||||
|
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||||
|
ax[i].annotate(
|
||||||
|
f"{self.eye_stats[i]['areas'][j]:.2e}",
|
||||||
|
xy=(horizontal + 0.035, vertical - 0.07),
|
||||||
|
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||||
|
)
|
||||||
|
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
|
|
||||||
if show:
|
if show:
|
||||||
plt.show()
|
plt.show()
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
def analyse(self, n_levels=4):
|
def analyse_single(self, data, index):
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
for i in range(self.channels):
|
eye_stats = {}
|
||||||
self.eye_stats[i]["channel"] = str(i+1) if self.channel_names is None else self.channel_names[i]
|
eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index]
|
||||||
try:
|
try:
|
||||||
approx_levels = eye_diagram.approximate_levels(self.raw_data[i], n_levels)
|
approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
|
||||||
|
|
||||||
time_bounds = eye_diagram.calculate_time_bounds(self.raw_data[i], approx_levels)
|
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
|
||||||
|
|
||||||
self.eye_stats[i]["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
|
eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
|
||||||
|
|
||||||
self.eye_stats[i]["levels"], self.eye_stats[i]["amplitude_clusters"] = eye_diagram.calculate_levels(
|
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
|
||||||
self.raw_data[i], approx_levels, time_bounds
|
data, approx_levels, time_bounds
|
||||||
)
|
)
|
||||||
|
|
||||||
self.eye_stats[i]["amplitudes"] = np.diff(self.eye_stats[i]["levels"])
|
eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
|
||||||
|
|
||||||
self.eye_stats[i]["heights"] = eye_diagram.calculate_eye_heights(
|
eye_stats["heights"] = eye_diagram.calculate_eye_heights(
|
||||||
self.eye_stats[i]["amplitude_clusters"]
|
eye_stats["amplitude_clusters"]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.eye_stats[i]["widths"], self.eye_stats[i]["time_clusters"] = eye_diagram.calculate_eye_widths(
|
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
|
||||||
self.raw_data[i], self.eye_stats[i]["levels"]
|
data, eye_stats["levels"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# # check if time clusters are valid (upper bound > time_midpoint > lower bound)
|
# # check if time clusters are valid (upper bound > time_midpoint > lower bound)
|
||||||
# # if not: raise ValueError
|
# # if not: raise ValueError
|
||||||
# for j in range(len(self.eye_stats[i]['time_clusters'])):
|
# for j in range(len(eye_stats['time_clusters'])):
|
||||||
# if not (np.max(self.eye_stats[i]['time_clusters'][j][0]) < self.eye_stats[i]["time_midpoint"] < np.min(self.eye_stats[i]['time_clusters'][j][1])):
|
# if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
|
||||||
# raise ValueError
|
# raise ValueError
|
||||||
|
|
||||||
self.eye_stats[i]["areas"] = self.eye_stats[i]["heights"] * self.eye_stats[i]["widths"]
|
eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
|
||||||
self.eye_stats[i]["mean_area"] = np.mean(self.eye_stats[i]["areas"])
|
eye_stats["mean_area"] = np.mean(eye_stats["areas"])
|
||||||
self.eye_stats[i]["min_area"] = np.min(self.eye_stats[i]["areas"])
|
eye_stats["min_area"] = np.min(eye_stats["areas"])
|
||||||
|
|
||||||
self.eye_stats[i]["success"] = True
|
|
||||||
except (RuntimeWarning, UserWarning, ValueError):
|
|
||||||
self.eye_stats[i]["success"] = False
|
|
||||||
self.eye_stats[i]["time_midpoint"] = 0
|
|
||||||
self.eye_stats[i]["levels"] = np.zeros(n_levels)
|
|
||||||
self.eye_stats[i]["amplitude_clusters"] = []
|
|
||||||
self.eye_stats[i]["amplitudes"] = np.zeros(n_levels - 1)
|
|
||||||
self.eye_stats[i]["heights"] = np.zeros(n_levels - 1)
|
|
||||||
self.eye_stats[i]["widths"] = np.zeros(n_levels - 1)
|
|
||||||
self.eye_stats[i]["areas"] = np.zeros(n_levels - 1)
|
|
||||||
self.eye_stats[i]["mean_area"] = 0
|
|
||||||
self.eye_stats[i]["min_area"] = 0
|
|
||||||
|
|
||||||
|
eye_stats["success"] = True
|
||||||
|
except (RuntimeWarning, UserWarning, ValueError):
|
||||||
|
eye_stats["success"] = False
|
||||||
|
eye_stats["time_midpoint"] = 0
|
||||||
|
eye_stats["levels"] = np.zeros(self.n_levels)
|
||||||
|
eye_stats["amplitude_clusters"] = []
|
||||||
|
eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
|
||||||
|
eye_stats["heights"] = np.zeros(self.n_levels - 1)
|
||||||
|
eye_stats["widths"] = np.zeros(self.n_levels - 1)
|
||||||
|
eye_stats["areas"] = np.zeros(self.n_levels - 1)
|
||||||
|
eye_stats["mean_area"] = 0
|
||||||
|
eye_stats["min_area"] = 0
|
||||||
warnings.resetwarnings()
|
warnings.resetwarnings()
|
||||||
|
return eye_stats
|
||||||
|
|
||||||
|
|
||||||
|
def analyse(self):
|
||||||
|
self.eye_stats = []
|
||||||
|
if self.multi_threaded:
|
||||||
|
with multiprocessing.Pool() as pool:
|
||||||
|
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
self.eye_stats.append(result)
|
||||||
|
else:
|
||||||
|
for i in range(self.channels):
|
||||||
|
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def approximate_levels(data, levels):
|
def approximate_levels(data, levels):
|
||||||
|
|||||||
82160
src/visualization/viz.ipynb
Normal file
82160
src/visualization/viz.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user