This commit is contained in:
Joseph Hopfmüller
2025-01-27 21:05:49 +01:00
parent f38d0ca3bb
commit 249fe1e940
19 changed files with 2266 additions and 880 deletions

3
.gitignore vendored
View File

@@ -163,4 +163,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
tolerance_results/datasets/*
tolerance_results/*
data/*

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
size 649

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
size 134481920

View File

@@ -0,0 +1,59 @@
# Baseline Models
## a) D+S, pol_error 0, ortho_error 0, DGD 0
dataset
```raw
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250118_225918.tar
```
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
dataset
```raw
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250116_214816.tar
```
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
dataset
```raw
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_122319.tar
```
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
birefringence angle pi/2 (worst case)
dataset
```raw
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_144001.tar
```

2
pypho

Submodule pypho updated: dd015f4852...e44fc477fe

View File

@@ -164,10 +164,14 @@ class regenerator(Module):
module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
module = Scale(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if self.rotation:
module = rotate()
self.add_module("rotate", module)
# module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)

View File

@@ -18,9 +18,11 @@ class DataSettings:
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
drop_first: int = 1000
drop_first: int = 64
drop_last: int = 64
train_split: float = 0.8
polarisations: tuple | list = (0,)
# cross_pol_interference: float = 0
randomise_polarisations: bool = False
osnr: float | int = None
seed: int = None
@@ -93,6 +95,12 @@ class ModelSettings:
"""
def _early_stop_default_kwargs():
return {
"threshold": 1e-05,
"plateau": 25,
}
@dataclass
class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
@@ -101,6 +109,9 @@ class OptimizerSettings:
scheduler: str | None = None
scheduler_kwargs: dict | None = None
early_stopping: bool = False
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
"""
change to:

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import random
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch.nn.utils.parametrize
try:
@@ -46,13 +47,72 @@ from .settings import (
PytorchSettings,
)
from cmcrameri import cm
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
def pad_to_size(array, size):
if not hasattr(size, "__len__"):
size = (size, size)
left = (
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
)
right = (
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
)
top = (
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
)
bottom = (
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
)
array: np.ndarray = array
if array.ndim == 2:
return np.pad(
array,
(
(left, right),
(top, bottom),
),
constant_values=(np.nan, np.nan),
)
elif array.ndim == 3:
return np.pad(
array,
(
(left, right),
(top, bottom),
(0,0)
),
constant_values=(np.nan, np.nan),
)
def traverse_dict_update(target, source):
for k, v in source.items():
if isinstance(v, dict):
if k not in target:
target[k] = {}
traverse_dict_update(target[k], v)
try:
if k not in target:
target[k] = {}
traverse_dict_update(target[k], v)
except TypeError:
if k not in target.__dict__:
setattr(target, k, {})
traverse_dict_update(target.__dict__[k], v)
else:
try:
target[k] = v
@@ -261,6 +321,7 @@ class PolarizationTrainer:
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
@@ -602,6 +663,7 @@ class RegenerationTrainer:
console=None,
checkpoint_path=None,
settings_override=None,
new_model=False,
reset_epoch=False,
):
self.resume = checkpoint_path is not None
@@ -615,12 +677,23 @@ class RegenerationTrainer:
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
# self.new_model = True
self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
if self.resume:
print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
if not new_model:
# self.new_model = False
checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
if checkpoint_file.startswith("best"):
self.model_name = "_".join(checkpoint_file.split("_")[1:])
else:
self.model_name = "_".join(checkpoint_file.split("_")[:-1])
if new_model or reset_epoch:
self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
@@ -654,7 +727,7 @@ class RegenerationTrainer:
self.writer = None
def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
if append is not None:
log_dir += "_" + str(append)
@@ -697,8 +770,8 @@ class RegenerationTrainer:
output_dim = self.model_settings.output_dim
# if self.data_settings.polarisations:
output_dim *= 2
if self.data_settings.polarisations:
output_dim *= 2
dtype = getattr(torch, self.data_settings.dtype)
@@ -755,11 +828,13 @@ class RegenerationTrainer:
randomise_polarisations = self.data_settings.randomise_polarisations
polarisations = self.data_settings.polarisations
osnr = self.data_settings.osnr
# cross_pol_interference = self.data_settings.cross_pol_interference
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# cross_pol_interference = override.get("angle_var", 0)
# get dataset
dataset = FiberRegenerationDataset(
file_path=config_path,
@@ -768,11 +843,13 @@ class RegenerationTrainer:
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
# cross_pol_interference=cross_pol_interference,
osnr = osnr,
)
@@ -842,8 +919,10 @@ class RegenerationTrainer:
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
for batch_idx, batch in enumerate(train_loader):
x = batch[x_key]
y = batch[y_key]
@@ -855,7 +934,10 @@ class RegenerationTrainer:
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x, -angle)
# loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
optimizer.step()
@@ -898,8 +980,10 @@ class RegenerationTrainer:
self.model.eval()
running_error = 0
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch[x_key]
@@ -911,7 +995,9 @@ class RegenerationTrainer:
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x, -angle)
# error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
@@ -928,7 +1014,7 @@ class RegenerationTrainer:
if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1)
head_fig, eye_fig, powers_fig = self.plot_model_response(
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
@@ -944,6 +1030,11 @@ class RegenerationTrainer:
eye_fig,
epoch + 1,
)
self.writer.add_figure(
"weights",
weight_fig,
epoch + 1,
)
self.writer.add_figure(
"powers",
@@ -967,9 +1058,10 @@ class RegenerationTrainer:
regen = []
timestamps = []
angles = []
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for batch in loader:
@@ -1056,7 +1148,7 @@ class RegenerationTrainer:
)
title_append, subtitle = self.build_title(0)
head_fig, eye_fig, powers_fig = self.plot_model_response(
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
@@ -1072,6 +1164,11 @@ class RegenerationTrainer:
eye_fig,
0,
)
self.writer.add_figure(
"weights",
weight_fig,
0,
)
self.writer.add_figure(
"powers",
@@ -1103,6 +1200,9 @@ class RegenerationTrainer:
train_loader, valid_loader = self.get_sliced_data()
# train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
# train_loader.dataset.fiber_in.to(self.pytorch_settings.device)
optimizer_name = self.optimizer_settings.optimizer
# lr = self.optimizer_settings.learning_rate
@@ -1132,6 +1232,7 @@ class RegenerationTrainer:
# except ValueError:
# pass
self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
if enable_progress:
@@ -1147,29 +1248,64 @@ class RegenerationTrainer:
epoch,
enable_progress=enable_progress,
)
if self.early_stop(loss):
self.save_model_checkpoints(epoch, loss)
break
if self.optimizer_settings.scheduler is not None:
self.scheduler.step(loss)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
self.save_model_checkpoints(epoch, loss)
self.writer.flush()
save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
print(f"Training complete. Best checkpoint: {save_path}")
self.writer.close()
return self.best
def early_stop(self, loss):
# not stopping early at all
if not self.optimizer_settings.early_stopping:
return False
# stopping because of abs threshold
if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None:
if loss <= loss_thr:
print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
return True
# update vals
if loss < self.early_stop_vals["min_loss"]:
self.early_stop_vals["min_loss"] = loss
self.early_stop_vals["plateau_cnt"] = 0
return False
# stopping because of plateau
if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None:
self.early_stop_vals["plateau_cnt"] += 1
if self.early_stop_vals["plateau_cnt"] >= plateau_thresh:
print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals["plateau_cnt"]} >= {plateau_thresh})")
return True
# no stop
return False
def save_model_checkpoints(self, epoch, loss):
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
powers = [power / powers[0] for power in powers]
fig, ax = plt.subplots()
@@ -1190,6 +1326,77 @@ class RegenerationTrainer:
plt.show()
return fig
def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
model_params = []
plots = []
dims = []
for num, (layer_name, layer) in enumerate(model.named_children()):
onn_weights = layer.ONN.weight
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles)})
dims.append(onn_weights.shape[0])
max_size = np.max(dims)
for plot in plots:
layer_name, (num, onn_values, onn_angles) = plot.popitem()
if num == 0:
value_img = onn_values
angle_img = onn_angles
onn_angles = pad_to_size(onn_angles, (max_size, None))
onn_values = pad_to_size(onn_values, (max_size, None))
else:
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
value_img = np.concatenate((value_img, onn_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
fig.tight_layout()
dividers = map(make_axes_locatable, axs)
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
masked_value_img = value_img
cmap = cm.batlow
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
cmap = cm.romaO
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])
axs[0].axis("off")
axs[1].axis("off")
axs[0].set_title("Values")
axs[1].set_title("Angles")
title = "Layer Weights"
if title_append:
title += f" {title_append}"
if subtitle:
title += f"\n{subtitle}"
fig.suptitle(title)
if show:
plt.show()
return fig
def _plot_model_response_eye(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
@@ -1354,7 +1561,7 @@ class RegenerationTrainer:
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.drop_first = int(64 + random.randint(0, 1000))
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
@@ -1363,7 +1570,7 @@ class RegenerationTrainer:
if isinstance(self.data_settings.config_path, (list, tuple))
else self.data_settings.config_path
)
fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
# fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
if not hasattr(self, "_plot_loader"):
self._plot_loader, _ = self.get_sliced_data(
override={
@@ -1376,6 +1583,7 @@ class RegenerationTrainer:
}
)
self._sps = self._plot_loader.dataset.samples_per_symbol
fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
@@ -1403,7 +1611,7 @@ class RegenerationTrainer:
import gc
head_fig = self._plot_model_response_head(
fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps],
fiber_out[: self.pytorch_settings.head_symbols * self._sps],
fiber_in[: self.pytorch_settings.head_symbols * self._sps],
regen[: self.pytorch_settings.head_symbols * self._sps],
angles[: self.pytorch_settings.head_symbols * self._sps],
@@ -1417,7 +1625,7 @@ class RegenerationTrainer:
# raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye(
fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps],
fiber_out[: self.pytorch_settings.eye_symbols * self._sps],
regen[: self.pytorch_settings.eye_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
labels=("fiber in", "fiber out", "regen"),
@@ -1426,9 +1634,11 @@ class RegenerationTrainer:
subtitle=subtitle,
show=show,
)
weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
gc.collect()
return head_fig, eye_fig, power_fig
return head_fig, eye_fig, weight_fig, power_fig
def build_title(self, number: int):
title_append = f"epoch {number}"

View File

@@ -1,4 +1,6 @@
import os
from pathlib import Path
import sys
from matplotlib import pyplot as plt
import numpy as np
import torch
@@ -26,6 +28,28 @@ from hypertraining import models
# constant_values=(-np.inf, -np.inf),
# )
def register_puccs_cmap(puccs_path=None):
puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
colors = []
# keys = None
with open(puccs_path, "r") as f:
for i, line in enumerate(f.readlines()):
elements = tuple(line.split(","))
# if i == 0:
# # keys = elements
# continue
# else:
try:
colors.append(tuple(map(float, elements[4:])))
except ValueError:
continue
# colors = []
# for current in puccs_csv_data:
# colors.append(tuple(current[4:]))
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
def pad_to_size(array, size):
if not hasattr(size, "__len__"):
@@ -65,7 +89,7 @@ def pad_to_size(array, size):
constant_values=(np.nan, np.nan),
)
def model_plot(model_path):
def model_plot(model_path, show=True):
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
@@ -81,173 +105,113 @@ def model_plot(model_path):
dims = checkpoint_dict["model_kwargs"].pop("dims")
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
model.load_state_dict(checkpoint_dict["model_state_dict"])
model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)
model_params = []
plots = []
max_size = np.max(dims)
# max_act_size = np.max(dims[1:])
angles = [None, None]
weights = [None, None]
# angles = [None, None]
# weights = [None, None]
for num, (layer_name, layer) in enumerate(model.named_children()):
# each layer contains an "ONN" layer and an "activation" layer
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
onn_weights = layer.ONN.weight.T
onn_weights = layer.ONN.weight
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
act = layer.activation
act_values = np.ones((act.size, 1))
act_values = np.nan * act_values
act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy()
...
# act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8)
# act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8)
# xs = (0.01, 0.1, 1)
# act_values = np.zeros((act.size, len(xs)*2))
# act_angles = np.zeros((act.size, len(xs)*2))
# act_values[:,:] = np.nan
# act_angles[:,:] = np.nan
# for xi, x in enumerate(xs):
# phi_intermediate = act_phi_gain * x**2 + act_phi_bias
# act_resulting_gain = (
# 1j
# * torch.sqrt(1-act.alpha)
# * torch.exp(-0.5j * phi_intermediate)
# * torch.cos(0.5 * phi_intermediate)
# * x
# )
# act_resulting_gain = act_resulting_gain.detach().cpu().numpy()
# act_values[:, xi*2] = np.abs(act_resulting_gain).real
# act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real
# if angles[0] is None or angles[0] > np.min(onn_angles.flatten()):
# angles[0] = np.min(onn_angles.flatten())
# if angles[1] is None or angles[1] < np.max(onn_angles.flatten()):
# angles[1] = np.max(onn_angles.flatten())
# if weights[0] is None or weights[0] > np.min(onn_weights.flatten()):
# weights[0] = np.min(onn_weights.flatten())
# if weights[1] is None or weights[1] < np.max(onn_weights.flatten()):
# weights[1] = np.max(onn_weights.flatten())
model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)})
plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
for plot in plots:
layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem()
# for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0])
# for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0])
onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values))
onn_values = onn_values - np.min(onn_values)
onn_values = onn_values / np.max(onn_values)
act_values = np.ma.array(act_values, mask=np.isnan(act_values))
act_values = act_values - np.min(act_values)
act_values = act_values / np.max(act_values)
onn_values = onn_values
onn_values = pad_to_size(onn_values, (max_size, None))
act_values = act_values
act_values = pad_to_size(act_values, (max_size, 3))
onn_angles = onn_angles / np.pi
onn_angles = pad_to_size(onn_angles, (max_size, None))
act_angles = act_angles / np.pi
act_angles = pad_to_size(act_angles, (max_size, 3))
# onn_angles = onn_angles - np.min(onn_angles)
# onn_angles = onn_angles / np.max(onn_angles)
# act_angles = act_angles - np.min(act_angles)
# act_angles = act_angles / np.max(act_angles)
layer_name, (num, onn_values, onn_angles) = plot.popitem()
if num == 0:
value_img = np.concatenate((onn_values, act_values), axis=1)
angle_img = np.concatenate((onn_angles, act_angles), axis=1)
value_img = onn_values
angle_img = onn_angles
onn_angles = pad_to_size(onn_angles, (max_size, None))
onn_values = pad_to_size(onn_values, (max_size, None))
else:
value_img = np.concatenate((value_img, onn_values, act_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1)
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
value_img = np.concatenate((value_img, onn_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
# from cmcrameri import cm
from cmap import Colormap as cm
import scicomap as sc
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
from mpl_toolkits.axes_grid1 import make_axes_locatable
# -np.inf to np.nan
# value_img[value_img == -np.inf] = np.nan
# angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size)
# angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size)
from cmcrameri import cm
from matplotlib import colors as mcolors
alpha_map = mcolors.LinearSegmentedColormap(
'alphamap',
{
'red': [(0, 0, 0), (1, 0, 0)],
'green': [(0, 0, 0), (1, 0, 0)],
'blue': [(0, 0, 0), (1, 0, 0)],
'alpha': [
(0, 1, 1),
# (0.2, 0.2, 0.1),
(1, 0, 0)
]
}
)
alpha_map.set_bad(color="#AAAAAA")
fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5))
fig.tight_layout()
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
# fig.tight_layout()
dividers = map(make_axes_locatable, axs)
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
masked_value_img = value_img
cmap = cm.batlowW
cmap = cm('google:turbo').to_matplotlib()
# cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap)
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
cmap = cm.romaO
# cmap = cm('crameri:romao').to_matplotlib()
# cmap = plt.get_cmap('puccs')
# cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
cmap = cm('colorcet:CET_C8').to_matplotlib()
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap)
im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])
# im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
# im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
axs[0].axis("off")
axs[1].axis("off")
axs[2].axis("off")
# axs[2].axis("off")
axs[0].set_title("Values")
axs[1].set_title("Angles")
axs[2].set_title("Values and Angles")
# axs[2].set_title("Values and Angles")
...
plt.show()
if show:
plt.show()
return fig
# model = models.regenerator(*dims, **model_kwargs)
if __name__ == "__main__":
model_plot(".models/best_20250105_145719.tar")
register_puccs_cmap()
if len(sys.argv) > 1:
model_plot(sys.argv[1])
else:
print("Please provide a model path as an argument")
# model_plot(".models/best_20250114_224234.tar")

View File

@@ -0,0 +1,102 @@
"x","L","a","b","R","G","B"
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
1 x L a b R G B
2 0. 0.5187848173343539 0.6399990176455989 0.67 0.8889427469969852 0.22673227640012172 0.
3 0.01 0.5374499525557803 0.604014067614707 0.6777967519386492 0.8956274406155226 0.27553288030331824 0.
4 0.02 0.5560867887452998 0.5680836759482211 0.6855816828789898 0.9019507507843885 0.318608215541461 0.
5 0.03 0.5746877595125583 0.5322224300667823 0.6933516322080414 0.907905487190649 0.3580633000693721 0.
6 0.04 0.5932314662487472 0.49647158484797804 0.7010976613543587 0.9134808162089558 0.3949845524063657 0.
7 0.05 0.6117000836392819 0.46086550613202343 0.7088123243737041 0.918668356138916 0.43002019316005363 0.
8 0.06 0.6300828534995973 0.4254249348741487 0.7164911273850869 0.923462736751354 0.4635961938811463 0.
9 0.07 0.6483763163456417 0.3901565406944371 0.7241326253017896 0.9278609626724071 0.49601354353255284 0.
10 0.08 0.6665840140182806 0.3550534951951814 0.7317382976124045 0.9318616057744784 0.5274983630587982 0.
11 0.09 0.6847162776119433 0.3200958808181962 0.7393124597949372 0.9354640163365924 0.5582303922647159 0.
12 0.1 0.7027902128942014 0.2852507189547545 0.7468622572263107 0.9386675557407496 0.5883604892249517 0.004034952213848706
13 0.11 0.7208298719332069 0.25047163906104203 0.7543977368741345 0.9414708123927996 0.6180221032545026 0.016031521294251994
14 0.12 0.7388665670611175 0.2156982733607376 0.7619319784446927 0.943870754968487 0.6473392272576862 0.029857267582036696
15 0.13 0.7569392765472108 0.18085547473834482 0.7694812638396673 0.9458617774020323 0.676432172396153 0.045365670193636125
16 0.14 0.7750950944867471 0.14585244938794778 0.7770652650825484 0.9474345911958609 0.7054219201084561 0.06017985923530026
17 0.15 0.793389684293558 0.11058188251425949 0.7847072337503834 0.9485749196617762 0.7344334940032564 0.07418869502646075
18 0.16 0.8117919447684838 0.07510373484536464 0.792394178330817 0.9492596163836376 0.7634480277996188 0.08767517868137237
19 0.17 0.8293050962981561 0.03629277424762101 0.799038155466063 0.9462308253550155 0.7922009241807345 0.10066327128139077
20 0.18 0.8213303100752708 -0.0062517290795987 0.7879999288492758 0.9088702681901394 0.7940579017644396 0.10139639009534024
21 0.19 0.8134831311534617 -0.048115463155645855 0.7771383286984362 0.8716809050191757 0.7954897210083888 0.10232311621802098
22 0.2 0.80558613530069 -0.0902449644291895 0.7662077749032042 0.8337524177888596 0.7965471523787845 0.10344968926026826
23 0.21 0.7975860185564765 -0.13292460297117392 0.7551344872795225 0.7947193410849823 0.7972381033243311 0.10477682283894393
24 0.22 0.7894147026971006 -0.17651756772919341 0.7438242359834689 0.7540941866826836 0.7975605026647324 0.10631182441371936
25 0.23 0.7809997374598548 -0.2214103719409295 0.7321767396537806 0.7112894518675287 0.7974995317311054 0.1080672415170634
26 0.24 0.7722646970273015 -0.2680107379394189 0.7200862142018722 0.6655745739336695 0.7970267795229349 0.11006041388465265
27 0.25 0.7631307298557146 -0.3167393290089981 0.7074435179925446 0.6160047476007512 0.7960993904970947 0.11231257117602686
28 0.26 0.7535192192483822 -0.36801555555407994 0.6941398344519211 0.5612859274945571 0.794659599537827 0.11484733363789801
29 0.27 0.7433557597838075 -0.42223636134393283 0.6800721760037781 0.4994862901720824 0.7926351396848288 0.11768844813479104
30 0.28 0.732575139048096 -0.479749646583324 0.6651502794883674 0.42731393423789277 0.7899410218414098 0.12085678487511567
31 0.29 0.7211269294461059 -0.5408244362880141 0.6493043460161184 0.3378265607222193 0.786483110019224 0.124366774034814
32 0.3 0.7090756028785993 -0.6051167807996883 0.6326236137723747 0.2098475715121697 0.7821998608677176 0.12819222127525928
33 0.31 0.7094510768540225 -0.6165036055456403 0.5630307498747129 0.15061488620640032 0.7845112116922692 0.21943537230975235
34 0.32 0.7174669421288304 -0.5917687864932311 0.4797229624661701 0.18766933782916642 0.7905828987725732 0.31091344246312086
35 0.33 0.7249009746435938 -0.5688293479200438 0.40246208306061504 0.21160609617940718 0.7962175427587832 0.38519766326885596
36 0.34 0.7317072855135611 -0.5478268906666535 0.3317250285377912 0.22717569971119178 0.8013847719431052 0.4490960048955565
37 0.35 0.7379328517830899 -0.5286164561226088 0.26702357292455026 0.23690087622812972 0.8061220291668977 0.5056371468159843
38 0.36 0.7436229063122554 -0.5110584677642499 0.20788761731555405 0.24226377668817778 0.8104638164122776 0.5563570758573497
39 0.37 0.7488251728809415 -0.4950056627547577 0.15382117501783654 0.24424372086048424 0.8144455902164638 0.6022301663745243
40 0.38 0.7535943992285348 -0.48028910419451787 0.10425526029155024 0.24352232677523483 0.818107753931944 0.6440238320299774
41 0.39 0.757994865186593 -0.4667104416936734 0.05852182167144754 0.240562414747303 0.8214980148949816 0.6824536572462205
42 0.4 0.7620994844391137 -0.4540446830999986 0.015863077249098356 0.2356325204239052 0.8246710357361025 0.7182393675419642
43 0.41 0.7659871096124125 -0.4420485102716773 -0.024540477496154123 0.22880568593963535 0.8276865975886148 0.7521146815529202
44 0.42 0.7697410958994951 -0.4304647113488041 -0.06355514164248566 0.21993360985514526 0.8306086550266585 0.7848331944479765
45 0.43 0.773446484628189 -0.4190308715098135 -0.10206473803580057 0.20858849290850018 0.833503273690861 0.8171544357676854
46 0.44 0.7771893686864673 -0.4074813310994203 -0.14096401824224686 0.1939295692427068 0.8364382500400466 0.8498448067259188
47 0.45 0.7810574093604746 -0.3955455908045306 -0.18116403397486242 0.17438366103820427 0.839483669055626 0.8836865023336339
48 0.46 0.7851360804917298 -0.3829599011818591 -0.2235531031349741 0.14679145002531463 0.8427091517444469 0.9194481212717681
49 0.47 0.789525027020907 -0.369416784561489 -0.26916682191206776 0.10278921007810798 0.8461971304126237 0.9580316568065935
50 0.48 0.7942371698732826 -0.35487637041943493 -0.3181394757087982 0.0013920913109500188 0.8499626968466341 0.9995866371771526
51 0.49 0.7773897680996302 -0.31852357140025195 -0.34537976514700053 0.10740420703601522 0.8254781216972907 1.
52 0.5 0.7604011244310231 -0.28211213216592784 -0.3722846952738428 0.1581725581872408 0.8008522647497104 1.
53 0.51 0.7433440454962605 -0.2455540169176899 -0.3992980063927199 0.19300141807932156 0.7761561224913385 1.
54 0.52 0.7262590833969331 -0.20893614020926626 -0.42635547610418184 0.2194621842292243 0.751443124097109 1.
55 0.53 0.709058602701224 -0.17207067467417486 -0.453595892719742 0.2405673704012788 0.7265803324554873 1.
56 0.54 0.6915768892539101 -0.1346024482921609 -0.48128169789479536 0.25788347992973676 0.701321051230534 1.
57 0.55 0.6736331627810209 -0.09614399811510127 -0.5096991935104321 0.2722888922216317 0.6753950894563805 1.
58 0.56 0.6551463184003872 -0.05652149358027936 -0.5389768254408652 0.28422807900785235 0.6486730893521468 1.
59 0.57 0.6361671326276888 -0.01584376303510615 -0.5690341788729347 0.293907374075009 0.6212117649042732 1.
60 0.58 0.6168396823565967 0.025580396234342995 -0.5996430791016598 0.301442767979156 0.5931976878638505 1.
61 0.59 0.5973210287815495 0.06741435793529688 -0.6305547881733555 0.30694603901024253 0.5648312189065924 1.
62 0.6 0.5777303704171711 0.10940264614179468 -0.661580531294122 0.3105418468883679 0.5362525958007331 1.
63 0.61 0.5581475370499237 0.15137416317967575 -0.6925938819599547 0.3123531986526998 0.5075386530652202 1.
64 0.62 0.5386227795100639 0.19322120739317136 -0.7235152578861672 0.31248922600720636 0.4787151440558522 1.
65 0.63 0.5191666876024412 0.23492108185347996 -0.754327887989376 0.31103663081260624 0.44973844514160927 1.
66 0.64 0.4996990584326256 0.2766456839100268 -0.7851587896650079 0.30803814950244496 0.4204116611935119 1.
67 0.65 0.479957679121191 0.3189570094767831 -0.8164232296840259 0.30343473603466015 0.390226489453496 1.
68 0.66 0.4600072725872886 0.3617163391430824 -0.8480187063016573 0.29717122075330515 0.3591178757512998 1.
69 0.67 0.44600100870220305 0.4113853615984094 -0.8697728377551008 0.3178994129506999 0.3295740682997879 1.
70 0.68 0.4574651571354146 0.44026390446569547 -0.8504539292487465 0.3842479358768364 0.3280946443367561 1.
71 0.69 0.4691809168948424 0.46977626401045774 -0.830711015748157 0.44293649140770447 0.3260767554252525 1.
72 0.7 0.4811696900083858 0.49997635259991063 -0.8105080314416201 0.49708450874457527 0.3234487047238236 1.
73 0.71 0.49350094811609174 0.5310391714342613 -0.7897279055963483 0.5485591109413528 0.3201099534066949 1.
74 0.72 0.5062548753068121 0.5631667067020758 -0.7682355153041539 0.5985798481027601 0.3159263917472715 1.
75 0.73 0.5195243020949684 0.5965928013272943 -0.7458744264238399 0.6480500606439057 0.31071717884730565 1.
76 0.74 0.5334043922713477 0.6315571758288618 -0.7224842728734379 0.6976685401842261 0.3042411890803418 1.
77 0.75 0.5479805812358602 0.6682750446095802 -0.697921082452685 0.7479712773579563 0.29618040787504757 1.
78 0.76 0.5633244502526606 0.7069267230777347 -0.6720642293775535 0.7993701361353484 0.28611136999256687 1.
79 0.77 0.5794956601139 0.7476624986056212 -0.6448131757501174 0.8521918014427678 0.2734527325942473 1.
80 0.78 0.5965429098573916 0.7906050455688622 -0.6160858559672187 0.9067003897516911 0.2573693489198746 1.
81 0.79 0.6145761476424179 0.8360313267658297 -0.5856969899409387 0.963334644317004 0.23648492980159264 1.
82 0.8 0.6232910688128902 0.859291371252556 -0.5300995185388214 1. 0.21867949406239662 0.9712088595948508
83 0.81 0.6159984336377875 0.8439887543380684 -0.44635440435952856 1. 0.21606849746358275 0.9041480210597966
84 0.82 0.6091642745073532 0.8296481879180277 -0.36787420852419694 1. 0.21421830096504035 0.8419706002336461
85 0.83 0.6025478038652375 0.8157644115969636 -0.2918938425681935 1. 0.21295365915197917 0.7823908751330636
86 0.84 0.5961857222953111 0.8024144366282877 -0.21883475834162458 0.9971140114799418 0.21220068235083267 0.7256713129328118
87 0.85 0.5900921771070883 0.7896279492437488 -0.1488594167412921 0.993273906363258 0.2118788857127918 0.671860243327784
88 0.86 0.5842771639541229 0.7774259239818333 -0.08208260304413262 0.9887084084529413 0.21191070453347688 0.6209624706933893
89 0.87 0.578741582584259 0.7658102488427286 -0.018514649521559012 0.9835846378805114 0.2122246941077346 0.5728987835613306
90 0.88 0.5734741590353537 0.7547572669288056 0.04197390858426542 0.9780378159372328 0.21275878699579343 0.5274829957183049
91 0.89 0.5684517008574971 0.7442183119942206 0.09964940221121898 0.9721670725313721 0.21346242315895625 0.4844270603851604
92 0.9 0.5636419856510335 0.7341257696545772 0.15488185789614228 0.9660363209686843 0.21429691147008262 0.4433660148378527
93 0.91 0.5590069340453534 0.7243997354573974 0.20810856081277884 0.9596781387247791 0.2152344151262528 0.4038812338146013
94 0.92 0.5545051525321143 0.7149533506766244 0.25980485409830323 0.9530986696850675 0.21625626438013962 0.3655130449917989
95 0.93 0.5500961975299247 0.705701749880514 0.3104351723857584 0.9462863346513658 0.21735046958786286 0.327780364198278
96 0.94 0.545740378056064 0.6965616468647046 0.36045530782708896 0.93921469089265 0.21851014470332586 0.29014917175372823
97 0.95 0.5414004092067859 0.6874548042588865 0.41029342232076466 0.9318478255642132 0.21973168075163751 0.2519897371806688
98 0.96 0.5370416605957644 0.6783085548415655 0.46034719456417006 0.9241434776436454 0.22101341980094052 0.2124579038400577
99 0.97 0.5326309593934517 0.6690532898786764 0.5109975653738162 0.9160532016485884 0.22235495330179011 0.17018252385769012
100 0.98 0.5281374148557197 0.6596241892863608 0.5625992691950712 0.90752576202319 0.22375597459867458 0.1223073280126531
101 0.99 0.5235317096396147 0.6499597345521199 0.615488972291106 0.8985077346125597 0.22521565729028564 0.05933950582860665
102 1. 0.5187848173343539 0.6399990176455989 0.67 0.8889427469969852 0.22673227640012172 0.

View File

@@ -26,28 +26,39 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
# config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
# config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=4, # study: single_core_regen_20241123_011232
symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
drop_first=64,
output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
shuffle=False,
drop_first=256,
drop_last=256,
train_split=0.8,
randomise_polarisations=False,
polarisations=True,
polarisations=False,
# cross_pol_interference=0.01,
osnr=16, #16dB due to amplification with NF 5
)
pytorch_settings = PytorchSettings(
epochs=1000,
batchsize=2**14,
batchsize=2**13,
device="cuda",
dataloader_workers=24,
dataloader_prefetch=8,
dataloader_workers=32,
dataloader_prefetch=4,
summary_dir=".runs",
write_every=2**5,
save_models=True,
@@ -65,16 +76,13 @@ model_settings = ModelSettings(
# "n_hidden_nodes_3": 4,
# "n_hidden_nodes_4": 2,
},
model_activation_func="phase_shift",
model_activation_func="EOActivation",
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=2.0,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
# EOactivation
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
@@ -83,54 +91,20 @@ model_settings = ModelSettings(
"max": 1,
},
},
# ONNRect
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": None,
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
"tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
},
# Scale
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2,
"max": 10,
},
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "bias",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
}
],
)
@@ -145,107 +119,19 @@ optimizer_settings = OptimizerSettings(
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.75,
"factor": 0.5,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
early_stopping=True,
early_stop_kwargs={
"threshold": 1e-06,
"plateau": 2**7,
}
)
def save_dict_to_file(dictionary, filename):
"""
Save the best dictionary to a JSON file.
:param best: Dictionary containing the best training results.
:type best: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
model = model
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
trainer = RegenerationTrainer(
checkpoint_path=model,
)
trainer.define_model()
for length in lengths:
data_glob_length = data_glob.replace("{length}", str(length))
files = list(Path.cwd().glob(data_glob_length))
if len(files) == 0:
continue
if strategy == "newest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
file = sorted(files, **sorted_kwargs)[0]
loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
data[2 + 2 * li, 0, :] = timestampss[length] / 128
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
channel_names[2 + 2 * li + 1] = f"regen x {length}"
channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot(all_stats=False)
matplotlib.use(backend)
if __name__ == "__main__":
# lengths = range(90000, 100000+10000, 10000)
# lengths = [100000]
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
trainer = RegenerationTrainer(
global_settings=global_settings,
@@ -253,83 +139,15 @@ if __name__ == "__main__":
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
# checkpoint_path=".models/best_20250104_191428.tar",
reset_epoch=True,
# settings_override={
# "data_settings": {
# "config_path": "data/20241229-163*-128-16384-100000-*.ini",
# "polarisations": True,
# },
# "model_settings": {
# "scale": 2.0,
# }
# }
checkpoint_path=".models/best_20250117_144001.tar",
new_model=True,
settings_override={
"data_settings": data_settings.__dict__,
# "optimizer_settings": {
# "optimizer_kwargs": {
# "lr": 0.01,
# },
# "early_stop_kwargs":{
# "plateau": 2**8,
# }
# }
# }
# 20241202_143149
}
)
trainer.train()
# from hypertraining.lighning_models import regenerator, regeneratorData
# import lightning as L
# model = regenerator(
# 2 * data_settings.output_size,
# *model_settings.overrides["hidden_layer_dims"],
# model_settings.output_dim,
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
# layer_func_kwargs=model_settings.model_layer_kwargs,
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
# act_func_kwargs=None,
# parametrizations=model_settings.model_layer_parametrizations,
# dtype=getattr(torch, data_settings.dtype),
# dropout_prob=model_settings.dropout_prob,
# scale_layers=model_settings.scale,
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
# )
# dm = regeneratorData(
# config_globs=data_settings.config_path,
# output_symbols=data_settings.symbols,
# output_dim=data_settings.output_size,
# dtype=getattr(torch, data_settings.dtype),
# drop_first=data_settings.drop_first,
# shuffle=data_settings.shuffle,
# train_split=data_settings.train_split,
# batch_size=pytorch_settings.batchsize,
# loader_settings={
# "num_workers": pytorch_settings.dataloader_workers,
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
# "pin_memory": True,
# "drop_last": True,
# },
# seed=global_settings.seed,
# )
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
# # from torch.utils.tensorboard import SummaryWriter
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
# trainer = L.Trainer(
# fast_dev_run=False,
# # max_epochs=pytorch_settings.epochs,
# max_epochs=2,
# enable_checkpointing=True,
# default_root_dir=f".models/{subdir}/",
# logger=logger,
# )
# trainer.fit(model, dm)

View File

@@ -12,6 +12,7 @@ Full license text in LICENSE file
"""
import configparser
# import copy
from datetime import datetime
import hashlib
from pathlib import Path
@@ -40,7 +41,7 @@ alpha = 0.2
D = 17
S = 0.058
bireflength = 10
max_delta_beta = 0.14
pmd_q = 0.2
; birefseed = 0xC0FFEE
[signal]
@@ -195,10 +196,14 @@ class pam_generator:
def initialize_fiber_and_data(config):
f0 = config["glova"].get("f0", None)
if f0 is None:
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
config["glova"]["f0"] = f0
py_glova = pypho.setup(
nos=config["glova"]["nos"],
sps=config["glova"]["sps"],
f0=config["glova"]["f0"],
f0=f0,
symbolrate=config["glova"]["symbolrate"],
wisdom_dir=config["glova"]["wisdom_dir"],
flags=config["glova"]["flags"],
@@ -216,7 +221,9 @@ def initialize_fiber_and_data(config):
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4)
laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
@@ -232,7 +239,12 @@ def initialize_fiber_and_data(config):
symbols_y[:3] = 0
# symbols_x += 1
cw = laser()
cw = laserx()
# cwy = lasery()
# cw[0]['E'][0] = cw[0]['E'][0]
# cw[0]['E'][1] = cwy[0]['E'][0]
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
@@ -251,13 +263,41 @@ def initialize_fiber_and_data(config):
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
## side channels
# df = 100
# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
# symbols_x_side = symbolsrc(pattern="random")
# symbols_y_side = symbolsrc(pattern="random")
# symbols_x_side[:3] = 0
# symbols_y_side[:3] = 0
# cw_left = laser(Df=-df)
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
# cw_right = laser(Df=df)
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
E_in_pure = source_signal[0]["E"]
nf = py_edfa.NF
source_signal = py_edfa(E=source_signal, NF=0)
py_edfa.NF = nf
pmean = py_edfa.Pmean
# ideal amplification to launch power into fiber
source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
py_edfa.NF = nf
py_edfa.Pmean = pmean
py_fiber = pypho.fiber(
glova=py_glova,
l=config["fiber"]["length"],
@@ -265,20 +305,29 @@ def initialize_fiber_and_data(config):
gamma=config["fiber"]["gamma"],
D=config["fiber"]["d"],
S=config["fiber"]["s"],
phi_max=0.02,
)
if config["fiber"].get("birefsteps", 0) > 0:
config["fiber"]["birefsteps"] = config["fiber"].get(
"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
)
if config["fiber"]["birefsteps"] > 0:
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l,
py_fiber.l / config["fiber"]["birefsteps"],
# maxDeltaD=config["fiber"]["d"]/5,
maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
config["fiber"]["length"],
config["fiber"]["bireflength"],
maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
seed=seed,
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200)
elif (dgd := config['fiber'].get('dgd', 0)) > 0:
py_fiber.birefarray = [
pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
]
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y)
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
def save_data(data, config, **metadata):
@@ -316,8 +365,11 @@ def save_data(data, config, **metadata):
f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
f"dgd = {config['fiber'].get('dgd', 0)}",
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
f"pol_error = {config['fiber'].get('pol_error', 0)}",
"",
"[signal]",
f"seed = {hex(seed)}" if seed else "; seed = not set",
@@ -346,24 +398,12 @@ def save_data(data, config, **metadata):
save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n'
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config["fiber"].get("birefsteps", 0),
config["fiber"].get("max_delta_beta", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
config_filename:Path = create_config_filename(config, data_dir, timestamp)
while config_filename.exists():
time.sleep(1)
config_filename = create_config_filename(config, data_dir=data_dir)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
config_filename = data_dir / lookup_file
with open(config_filename, "w") as f:
f.write(config_content)
@@ -376,11 +416,31 @@ def save_data(data, config, **metadata):
outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data)
print("Saved config to", config_filename)
print("Saved data to", save_dir / save_file)
# print("Saved config to", config_filename)
# print("Saved data to", save_dir / save_file)
return config_filename
def create_config_filename(config, data_dir:Path, timestamp=None):
if timestamp is None:
timestamp = datetime.now()
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config["fiber"].get("birefsteps", 0),
config["fiber"].get("pmd_q", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
return data_dir / lookup_file
def length_loop(config, lengths, save=True):
lengths = sorted(lengths)
@@ -388,7 +448,7 @@ def length_loop(config, lengths, save=True):
print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
cfiber()
@@ -416,51 +476,49 @@ def single_run_with_plot(config, save=True):
in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename
def single_run(config, save=True):
cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)
# mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
# print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_in / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
def single_run(config, save=True, silent=True):
cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
# transmit
cfiber()
# mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
# noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_out / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
# amplify
E_tmp = [{"E": cdata.E_out, "noise": noise}]
E_tmp = edfa(E=E_tmp)
# rotate
# ortho error
ortho_error = config["fiber"].get("ortho_error", 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
), axis=0)
pol_error = config['fiber'].get('pol_error', 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
), axis=0)
# output
cdata.E_out = E_tmp[0]["E"]
# noise = E_tmp[0]["noise"]
# mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_amp / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
config_filename = None
symbols = np.array(symbols)
if save:
config_filename = save_data(cdata, config, **{"symbols": symbols})
return cfiber,cdata,config_filename
if not silent:
print(f"Saved config to {config_filename}")
return cfiber, cdata, config_filename
def in_out_eyes(cfiber, cdata, show_pols=False):

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -441,8 +441,7 @@ class input_rotator(nn.Module):
# return out
#### as defined by zhang et al
#### as defined by zhang et alas
class DropoutComplex(nn.Module):
def __init__(self, p=0.5):
@@ -464,7 +463,7 @@ class Scale(nn.Module):
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x):
return x * self.scale
return x * torch.sqrt(self.scale)
def __repr__(self):
return f"Scale({self.size})"
@@ -546,35 +545,31 @@ class EOActivation(nn.Module):
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.rand(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.V_bias = nn.Parameter(torch.rand(size))
# self.register_buffer("gain", torch.ones(size))
# self.register_buffer("responsivity", torch.ones(size))
# self.register_buffer("V_pi", torch.ones(size))
self.reset_weights()
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.rand(self.size)
if "V_pi" in self._parameters:
self.V_pi.data = torch.rand(self.size)*3
# if "V_pi" in self._parameters:
# self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters:
self.gain.data = torch.rand(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
# if "responsivity" in self._parameters:
# self.responsivity.data = torch.ones(self.size)*0.9
# if "bias" in self._parameters:
# self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
phi_b = torch.pi * self.V_bias# / (self.V_pi)
g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
intermediate = g_phi * x.abs().square() + phi_b
return (
1j

View File

@@ -0,0 +1,105 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
import hashlib
import h5py
import numpy as _np
from scipy.interpolate import interp1d as _interp1d
from scipy.ndimage import gaussian_filter as _gaussian_filter
from ._brescount import bres_curve_count as _bres_curve_count
from pathlib import Path
__all__ = ['grid_count']
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
"""
Parameters
----------
`y` is the 1-d array of signal samples.
`window_size` is the number of samples to show horizontally in the
eye diagram. Typically this is twice the number of samples in a
"symbol" (i.e. in a data bit).
`offset` is the number of initial samples to skip before computing
the eye diagram. This allows the overall phase of the diagram to
be adjusted.
`size` must be a tuple of two integers. It sets the size of the
array of counts, (height, width). The default is (800, 640).
`fuzz`: If True, the values in `y` are reinterpolated with a
random "fuzz factor" before plotting in the eye diagram. This
reduces an aliasing-like effect that arises with the use of
Bresenham's algorithm.
`bounds` must be a tuple of two floating point values, (ymin, ymax).
These set the y range of the returned array. If not given, the
bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
`y.max() - y.min()`.
Return Value
------------
Returns a numpy array of integers.
"""
# hash input params
param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
cache_dir = Path.home()/".eyediagram"/".cache"
cache_dir.mkdir(parents=True, exist_ok=True)
if (cache_dir/param_hash).is_file():
try:
with h5py.File(cache_dir/param_hash, "r") as infile:
counts = infile["counts"][:]
if counts.len() != 0:
return counts
except:
pass
if size is None:
size = (800, 640)
height, width = size
dt = width / window_size
counts = _np.zeros((width, height), dtype=_np.int32)
if bounds is None:
ymin = y.min()
ymax = y.max()
yamp = ymax - ymin
ymin = ymin - 0.05*yamp
ymax = ymax + 0.05*yamp
ymax = _np.ceil(ymax*10)/10
ymin = _np.floor(ymin*10)/10
else:
ymin, ymax = bounds
start = offset
while start + window_size < len(y):
end = start + window_size
yy = y[start:end+1]
k = _np.arange(len(yy))
xx = dt*k
if fuzz:
f = _interp1d(xx, yy, kind='cubic')
jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
xx[1:-1] += jiggle
yd = f(xx)
else:
yd = yy
iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
_bres_curve_count(xx.astype(_np.int32), iyd, counts)
start = end
if blur != 0:
counts = _gaussian_filter(counts, sigma=blur)
with h5py.File(cache_dir/param_hash, "w") as outfile:
outfile.create_dataset("data", data=counts)
return counts

View File

@@ -25,13 +25,14 @@ import multiprocessing as mp
# def __len__(self):
# return len(self.indices)
def load_from_file(datapath):
if str(datapath).endswith('.h5'):
if str(datapath).endswith(".h5"):
symbols = None
with h5py.File(datapath, "r") as infile:
data = infile["data"][:]
try:
symbols = infile["symbols"][:]
symbols = np.swapaxes(infile["symbols"][:], 0, 1)
except KeyError:
pass
else:
@@ -40,7 +41,7 @@ def load_from_file(datapath):
return data, symbols
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser()
@@ -58,12 +59,20 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, d
data, orig_symbols = load_from_file(datapath)
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
orig_symbols = orig_symbols[skipfirst:symbols+skipfirst]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
data *= np.sqrt(normalize)
launch_power = float(config["signal"]["laser_power"])
output_power = float(config["signal"]["edfa_power"])
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
data[:, 0:2] *= np.sqrt(target_normalization)
# if normalize:
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
# a, b, c, d = data.T
@@ -132,13 +141,15 @@ class FiberRegenerationDataset(Dataset):
target_delay: float | int = 0,
xy_delay: float | int = 0,
drop_first: float | int = 0,
drop_last=0,
dtype: torch.dtype = None,
real: bool = False,
device=None,
# osnr: float|None = None,
polarisations = None,
polarisations=None,
randomise_polarisations: bool = False,
repeat_randoms: int = 1,
# cross_pol_interference: float = 0,
**kwargs,
):
"""
@@ -172,6 +183,7 @@ class FiberRegenerationDataset(Dataset):
assert drop_first >= 0, "drop_first must be non-negative"
self.randomise_polarisations = randomise_polarisations
# self.cross_pol_interference = cross_pol_interference
data_raw = None
self.config = None
@@ -181,6 +193,7 @@ class FiberRegenerationDataset(Dataset):
data, config, orig_syms = load_data(
file_path,
skipfirst=drop_first,
skiplast=drop_last,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=1000,
@@ -300,20 +313,18 @@ class FiberRegenerationDataset(Dataset):
# fiber_out: [E_out_x, E_out_y, timestamps]
# add noise related to amplification necessary due to splitting of the signal
gain_lin = output_dim*2
edfa_nf = float(self.config["signal"]["edfa_nf"])
nf_lin = 10**(edfa_nf/10)
f0 = float(self.config["glova"]["f0"])
noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
noise = torch.randn_like(fiber_out[:2, :])
noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
noise = noise * torch.sqrt(noise_add / noise_power)
fiber_out[:2, :] += noise
# gain_lin = output_dim*2
# gain_lin = 1
# edfa_nf = float(self.config["signal"]["edfa_nf"])
# nf_lin = 10**(edfa_nf/10)
# f0 = float(self.config["glova"]["f0"])
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
# noise = torch.randn_like(fiber_out[:2, :])
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
# noise = noise * torch.sqrt(noise_add / noise_power)
# fiber_out[:2, :] += noise
# if osnr is None:
# noisy = fiber_out[:2, :]
@@ -324,7 +335,6 @@ class FiberRegenerationDataset(Dataset):
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
if repeat_randoms > 1:
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
@@ -334,8 +344,9 @@ class FiberRegenerationDataset(Dataset):
if self.randomise_polarisations:
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
# start_angle = torch.rand(1) * 2 * torch.pi
# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
start_angle = torch.rand(1) * 2 * torch.pi
angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
else:
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
@@ -361,8 +372,6 @@ class FiberRegenerationDataset(Dataset):
# 4 E_out_y_rot,
# 5 angle
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout
# [ [E_in_x, E_in_y, timestamps],
@@ -374,9 +383,12 @@ class FiberRegenerationDataset(Dataset):
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_out = self.fiber_out.movedim(-2, 0)
# if self.randomise_polarisations:
# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
# self.data = self.data.movedim(-2, 0)
# self.angles = torch.zeros(self.data.shape[0])
# self.angles = torch.zeros(self.data.shape[0])
...
# ...
# -> [no_slices, 2, 3, samples_per_slice]
@@ -392,12 +404,12 @@ class FiberRegenerationDataset(Dataset):
return self.fiber_in.shape[0]
def add_noise(self, data, osnr):
osnr_lin = 10**(osnr/10)
osnr_lin = 10 ** (osnr / 10)
popt = torch.mean(data.abs().square().squeeze(), dim=-1)
noise = torch.randn_like(data)
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
mult = torch.sqrt(popt/(pn*osnr_lin))
mult = torch.sqrt(popt / (pn * osnr_lin))
mult = mult * torch.eye(popt.shape[0], device=mult.device)
mult = mult.to(dtype=noise.dtype)
@@ -406,7 +418,6 @@ class FiberRegenerationDataset(Dataset):
noisy = data + noise
return noisy
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
@@ -418,6 +429,10 @@ class FiberRegenerationDataset(Dataset):
output_dim = self.output_dim // 2
self.output_dim = output_dim * 2
if not self.polarisations:
output_dim = 2 * output_dim
fiber_in = self.fiber_in[idx].squeeze()
fiber_out = self.fiber_out[idx].squeeze()
@@ -427,85 +442,35 @@ class FiberRegenerationDataset(Dataset):
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
# data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
# data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
# angle = self.angles[idx]
# fiber_in:
# 0 E_in_x,
# 1 E_in_y,
# 2 timestamps
# fiber_out:
# 0 E_out_x,
# 1 E_out_y,
# 2 timestamps,
# 3 E_out_x_rot,
# 4 E_out_y_rot,
# 5 angle
center_angle = fiber_out[0, output_dim // 2, 0]
center_angle = fiber_out[5, output_dim // 2, 0]
angles = fiber_out[5, :, 0]
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
data = fiber_out[0:2, :, 0]
# fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone()
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
# if self.polarisations:
# rot = int(np.random.randint(2)*2-1)
# pol_flipped_data[0:1, :] = rot*data[0, :]
# pol_flipped_data[1, :] = rot*data[1, :]
# plot_data_rot[0] = rot*plot_data_rot[0]
# plot_data_rot[1] = rot*plot_data_rot[1]
# center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
# angles = angles + (rot - 1) * torch.pi/2
# if self.randomise_polarisations:
# data = data.mT
# c = torch.cos(angle).unsqueeze(-1)
# s = torch.sin(angle).unsqueeze(-1)
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
# data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
...
# angle = torch.zeros_like(angle)
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
# angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim)
# angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim)
# sop = self.polarimeter(plot_data)
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
target = fiber_in[:2, output_dim // 2, 0]
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
target_timestamp = fiber_in[2, output_dim // 2, 0].real
...
if self.polarisations:
rot = int(np.random.randint(2)*2-1)
data = rot*data
target = rot*target
plot_data_rot = rot*plot_data_rot
center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
angles = angles + (rot - 1) * torch.pi/2
rot = int(np.random.randint(2) * 2 - 1)
data = rot * data
target = rot * target
plot_data_rot = rot * plot_data_rot
center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
angles = angles + (rot - 1) * torch.pi / 2
pol_flipped_data = -data
pol_flipped_target = -target
# data_timestamps = data[-1,:].real
# data = data[:-1, :]
# target_timestamp = target[-1].real
# target = target[:-1]
# plot_data = plot_data[:-1]
# transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze()
data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
pol_flipped_data = pol_flipped_data / torch.sqrt(
torch.ones(1) * len(pol_flipped_data)
) # power loss due to splitting
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
center_angle = center_angle.flatten().squeeze()
@@ -526,8 +491,8 @@ class FiberRegenerationDataset(Dataset):
"y": target,
"y_flipped": pol_flipped_target,
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
# "center_angle": center_angle,
# "angles": angles,
"center_angle": center_angle,
"angles": angles,
"mean_angle": angles.mean(),
# "sop": sop,
# "angle_data": angle_data,

View File

@@ -1,16 +1,23 @@
from datetime import datetime
import json
from pathlib import Path
from typing import Literal
import h5py
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# from cmap import Colormap as cm
import numpy as np
from scipy.cluster.vq import kmeans2
import warnings
import multiprocessing
from rich.traceback import install
from rich import pretty
from rich import print
install()
pretty.install()
# from rich import pretty
# from rich import print
# pretty.install()
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
xaxis = np.arange(0, len(signal)) / sps
return np.vstack([xaxis, signal])
def create_symbol_sequence(n_symbols, skew=1):
np.random.seed(42)
data = np.random.randint(0, 4, n_symbols) / 4
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
signal = np.convolve(data_padded, wavelet)
signal = np.cumsum(signal)
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
mi, ma = np.min(signal), np.max(signal)
signal = (signal - mi) / (ma - mi)
mod = 0.8
signal *= mod
signal += 1 - mod
return signal
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
signal += awgn
# min-max normalization
signal = signal - np.min(signal)
signal = signal / np.max(signal)
# signal = signal - np.min(signal)
# signal = signal / np.max(signal)
return signal
@@ -68,84 +84,248 @@ def generate_wavelet(sps, oversample=3):
class eye_diagram:
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
def __init__(
self,
data,
*,
channel_names=None,
horizontal_bins=256,
vertical_bins=1000,
n_levels=4,
multithreaded=True,
save_file_or_dir=None,
):
# data has shape [channels, 2, samples]
# each sample has a timestamp and a value
if data.ndim == 2:
data = data[np.newaxis, :, :]
self.channel_names = channel_names
self.raw_data = data
self.channels = data.shape[0]
self.y_bins = np.zeros(1)
self.x_bins = np.zeros(1)
self.eye_data = np.zeros(1)
self.channel_names = channel_names
self.n_channels = data.shape[0]
self.n_levels = n_levels
self.eye_stats = [{"success": False} for _ in range(self.channels)]
self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
self.horizontal_bins = horizontal_bins
self.vertical_bins = vertical_bins
self.multi_threaded = multithreaded
self.analysed = False
self.eye_built = False
def generate_eye_data(self):
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.channels, self.vertical_bins))
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
datas = [self.raw_data[i] for i in range(self.channels)]
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.map(self.generate_eye_data_single, datas)
for i, result in enumerate(results):
self.eye_data[i], self.y_bins[i] = result
self.save_file = save_file_or_dir
def load_data(self, file=None):
file = self.save_file if file is None else file
if file is None:
raise FileNotFoundError("No file specified.")
self.save_file = str(file)
# self.file_or_dir = self.save_file
with h5py.File(file, "r") as infile:
self.y_bins = infile["y_bins"][:]
self.x_bins = infile["x_bins"][:]
self.eye_data = infile["eye_data"][:]
self.channel_names = infile.attrs["channel_names"]
self.n_channels = infile.attrs["n_channels"]
self.n_levels = infile.attrs["n_levels"]
self.eye_stats = infile.attrs["eye_stats"]
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
self.horizontal_bins = infile.attrs["horizontal_bins"]
self.vertical_bins = infile.attrs["vertical_bins"]
self.multi_threaded = infile.attrs["multithreaded"]
self.analysed = infile.attrs["analysed"]
self.eye_built = infile.attrs["eye_built"]
def save_data(self, file_or_dir=None):
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
if file_or_dir is None:
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
elif Path(file_or_dir).is_dir():
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
else:
for i, data in enumerate(datas):
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True
file = Path(file_or_dir)
# file.parent.mkdir(parents=True, exist_ok=True)
self.save_file = str(file)
with h5py.File(file, "w") as outfile:
outfile.create_dataset("eye_data", data=self.eye_data)
outfile.create_dataset("y_bins", data=self.y_bins)
outfile.create_dataset("x_bins", data=self.x_bins)
outfile.attrs["channel_names"] = self.channel_names
outfile.attrs["n_channels"] = self.n_channels
outfile.attrs["n_levels"] = self.n_levels
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
outfile.attrs["horizontal_bins"] = self.horizontal_bins
outfile.attrs["vertical_bins"] = self.vertical_bins
outfile.attrs["multithreaded"] = self.multi_threaded
outfile.attrs["analysed"] = self.analysed
outfile.attrs["eye_built"] = self.eye_built
@staticmethod
def convert_arrays(input_object):
"""
convert ndarrays in (nested) dict to lists
"""
if isinstance(input_object, np.ndarray):
return input_object.tolist()
elif isinstance(input_object, list):
return [eye_diagram.convert_arrays(old) for old in input_object]
elif isinstance(input_object, tuple):
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
elif isinstance(input_object, dict):
dict_out = {}
for key, value in input_object.items():
dict_out[key] = eye_diagram.convert_arrays(value)
return dict_out
return input_object
def generate_eye_data(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
if not self.eye_built:
update_save = True
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
datas = [self.raw_data[i] for i in range(self.n_channels)]
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.map(self.generate_eye_data_single, datas)
for i, result in enumerate(results):
self.eye_data[i], self.y_bins[i] = result
else:
for i, data in enumerate(datas):
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
def generate_eye_data_single(self, data):
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
data_min = np.min(data[1, :])
data_max = np.max(data[1, :])
# round down/up to 1 decimal
data_min = np.floor(data_min*10)/10
data_max = np.ceil(data_max*10)/10
# data_range = data_max - data_min
# data_min -= 0.1 * data_range
# data_max += 0.1 * data_range
# data_min = -0.05
# data_max += 0.05
# data[1,:] -= np.min(data[1, :])
# data[1,:] /= np.max(data[1, :])
# data_min = 0
# data_max = 1
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
t_vals = data[0, :] % 2
val_vals = data[1, :]
t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
x_indices = np.digitize(t_vals, self.x_bins) - 1
y_indices = np.digitize(val_vals, y_bins) - 1
np.add.at(eye_data, (y_indices, x_indices), 1)
return eye_data, y_bins
def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
def plot(
self,
title="Eye Diagram",
stats=True,
all_stats=True,
show=True,
mode: Literal["default", "load", "save", "nosave"] = "default",
# save_images = False,
# image_dir = None,
# cmap=None,
):
if stats and not self.analysed:
self.analyse(mode=mode)
if not self.eye_built:
self.generate_eye_data()
self.generate_eye_data(mode=mode)
cmap = LinearSegmentedColormap.from_list(
"eyemap",
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
[
(0, "#FFFFFF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
)
if self.channels % 2 == 0:
# cmap = cm('google:turbo_r' if cmap is None else cmap)
# first = cmap(-1)
# cmap = cmap.to_mpl()
# cmap.set_under(first, alpha=0)
if self.n_channels % 2 == 0:
rows = 2
cols = self.channels // 2
cols = self.n_channels // 2
else:
cols = int(np.ceil(np.sqrt(self.channels)))
rows = int(np.ceil(self.channels / cols))
cols = int(np.ceil(np.sqrt(self.n_channels)))
rows = int(np.ceil(self.n_channels / cols))
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
fig.suptitle(title)
fig.tight_layout()
ax = np.atleast_1d(ax).transpose().flatten()
for i in range(self.channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
if (i+1) % rows == 0:
for i in range(self.n_channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
if (i + 1) % rows == 0:
ax[i].set_xlabel("Symbol")
if i < rows:
ax[i].set_ylabel("Amplitude")
ax[i].grid()
ax[i].set_axisbelow(True)
ax[i].imshow(
self.eye_data[i],
self.eye_data[i] - 0.1,
origin="lower",
aspect="auto",
cmap=cmap,
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
interpolation="gaussian",
vmin=0,
zorder=3,
)
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
ymin = np.min(self.y_bins[:, 0])
ymax = np.max(self.y_bins[:, -1])
yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
# if save_images:
# image_dir = "images_out" if image_dir is None else image_dir
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
# image_path.parent.mkdir(parents=True, exist_ok=True)
# # plt.imsave(
# # image_path,
# # self.eye_data[i] - 0.1,
# # origin="lower",
# # # aspect="auto",
# # cmap=cmap,
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
# # # interpolation="gaussian",
# # vmin=0,
# # # zorder=3,
# # )
if stats and self.eye_stats[i]["success"]:
# # add min_area above the plot
# ax[i].annotate(
@@ -159,7 +339,7 @@ class eye_diagram:
if all_stats:
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
y_ticks = (*self.eye_stats[i]["levels"],*self.eye_stats[i]["thresholds"])
y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
# y_ticks = np.sort(y_ticks)
ax[i].set_yticks(y_ticks)
# add arrows for amplitudes
@@ -235,19 +415,19 @@ class eye_diagram:
def calculate_thresholds(levels):
ret = np.cumsum(levels, dtype=float)
ret[2:] = ret[2:] - ret[:-2]
return ret[1:]/2
return ret[1:] / 2
def analyse_single(self, data, index):
warnings.filterwarnings("error")
eye_stats = {}
eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index]
eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
try:
approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2
eye_stats["time_midpoint"] = 1.0
eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
# eye_stats["time_midpoint"] = 1.0
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
data, approx_levels, time_bounds
@@ -257,9 +437,7 @@ class eye_diagram:
eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
eye_stats["heights"] = eye_diagram.calculate_eye_heights(
eye_stats["amplitude_clusters"]
)
eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
data, eye_stats["levels"]
@@ -291,17 +469,39 @@ class eye_diagram:
warnings.resetwarnings()
return eye_stats
def analyse(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
def analyse(self):
self.eye_stats = []
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
for i, result in enumerate(results):
self.eye_stats.append(result)
else:
for i in range(self.channels):
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
if not self.analysed:
update_save = True
self.eye_stats = []
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
for i, result in enumerate(results):
self.eye_stats.append(result)
else:
for i in range(self.n_channels):
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
self.analysed = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
@staticmethod
def approximate_levels(data, levels):
@@ -443,7 +643,7 @@ class eye_diagram:
if __name__ == "__main__":
length = int(2**14)
length = int(2**16)
# data = generate_sample_data(length, noise=1)
# data1 = generate_sample_data(length, noise=0.01)
# data2 = generate_sample_data(length, noise=0.01, skew=1.2)
@@ -451,13 +651,13 @@ if __name__ == "__main__":
# data = np.stack([data, data1, data2, data3])
data = generate_sample_data(length, noise=0.005)
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
eye.analyse()
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
for i, channel in enumerate(eye.eye_stats):
print(f"Channel {i}")
print_data = {attr: channel[attr] for attr in attrs}
print(print_data)
data = generate_sample_data(length, noise=0.0000)
eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
eye.plot(mode="nosave", stats=False)
# attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
# for i, channel in enumerate(eye.eye_stats):
# print(f"Channel {i}")
# print_data = {attr: channel[attr] for attr in attrs}
# print(print_data)
eye.plot()
# eye.plot()

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
# modified by Joseph Hopfmüller in 2025,
# for integration into optical regeneration analysis scripts
from pathlib import Path
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as colors
import numpy as _np
from .core import grid_count as _grid_count
import matplotlib.pyplot as _plt
import numpy as np
from scipy.ndimage import gaussian_filter
# from ._common import _common_doc
__all__ = ["eyediagram"] # , 'eyediagram_lines']
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
# """
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
# function.
# <common>
# """
# start = offset
# while start < len(y):
# end = start + window_size
# if end > len(y):
# end = len(y)
# yy = y[start:end+1]
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
# start = end
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
# _common_doc)
eyemap = LinearSegmentedColormap.from_list(
"eyemap",
[
(0, "#0000FF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
)
def eyediagram(
y,
window_size,
offset=0,
colorbar=False,
show=False,
save_im=False,
overwrite=False,
blur: int | bool = True,
save_path="out.png",
bounds=None,
**imshowkwargs,
):
"""
Plot an eye diagram using matplotlib by creating an image and calling
the `imshow` function.
<common>
"""
if bounds is None:
ymax = y.max()
ymin = y.min()
yamp = ymax - ymin
ymin = ymin - 0.05 * yamp
ymax = ymax + 0.05 * yamp
ymin = np.floor(ymin * 10) / 10
ymax = np.ceil(ymax * 10) / 10
bounds = (ymin, ymax)
counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
counts = counts.astype(_np.float32)
origin = imshowkwargs.pop("origin", "lower")
cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
vmin = imshowkwargs.pop("vmin", 1)
vmax = imshowkwargs.pop("vmax", None)
cmap.set_under("white", alpha=0)
if show:
_plt.imshow(
counts.T[::-1, :],
extent=[0, 2, *bounds],
origin=origin,
cmap=cmap,
vmin=vmin,
vmax=vmax,
**imshowkwargs,
)
_plt.grid()
if colorbar:
_plt.colorbar()
if Path(save_path).is_file() and not overwrite:
save_im = False
if save_im:
from PIL import Image
arr = counts.T[::-1, :]
if origin == "lower":
arr = arr[::-1]
arr = (arr-arr.min())/(arr.max()-arr.min())
image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
image.save(save_path)
# print("-")
if show:
_plt.show()
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)