This commit is contained in:
Joseph Hopfmüller
2025-01-27 21:05:49 +01:00
parent f38d0ca3bb
commit 249fe1e940
19 changed files with 2266 additions and 880 deletions

3
.gitignore vendored
View File

@@ -163,4 +163,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
tolerance_results/datasets/* tolerance_results/*
data/*

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
size 649

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
size 134481920

View File

@@ -0,0 +1,59 @@
# Baseline Models
## a) D+S, pol_error 0, ortho_error 0, DGD 0
dataset
```raw
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250118_225918.tar
```
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
dataset
```raw
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250116_214816.tar
```
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
dataset
```raw
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_122319.tar
```
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
birefringence angle pi/2 (worst case)
dataset
```raw
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_144001.tar
```

2
pypho

Submodule pypho updated: dd015f4852...e44fc477fe

View File

@@ -164,10 +164,14 @@ class regenerator(Module):
module = act_function(size=dims[-1], **act_func_kwargs) module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module) self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
module = Scale(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if self.rotation: if self.rotation:
module = rotate() module = rotate()
self.add_module("rotate", module) self.add_module("rotate", module)
# module = Scale(size=dims[-1]) # module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module) # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)

View File

@@ -18,9 +18,11 @@ class DataSettings:
shuffle: bool = True shuffle: bool = True
in_out_delay: float = 0 in_out_delay: float = 0
xy_delay: tuple | float | int = 0 xy_delay: tuple | float | int = 0
drop_first: int = 1000 drop_first: int = 64
drop_last: int = 64
train_split: float = 0.8 train_split: float = 0.8
polarisations: tuple | list = (0,) polarisations: tuple | list = (0,)
# cross_pol_interference: float = 0
randomise_polarisations: bool = False randomise_polarisations: bool = False
osnr: float | int = None osnr: float | int = None
seed: int = None seed: int = None
@@ -93,6 +95,12 @@ class ModelSettings:
""" """
def _early_stop_default_kwargs():
return {
"threshold": 1e-05,
"plateau": 25,
}
@dataclass @dataclass
class OptimizerSettings: class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD") optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
@@ -101,6 +109,9 @@ class OptimizerSettings:
scheduler: str | None = None scheduler: str | None = None
scheduler_kwargs: dict | None = None scheduler_kwargs: dict | None = None
early_stopping: bool = False
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
""" """
change to: change to:

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import random import random
import matplotlib import matplotlib
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch.nn.utils.parametrize import torch.nn.utils.parametrize
try: try:
@@ -46,13 +47,72 @@ from .settings import (
PytorchSettings, PytorchSettings,
) )
from cmcrameri import cm
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
def pad_to_size(array, size):
if not hasattr(size, "__len__"):
size = (size, size)
left = (
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
)
right = (
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
)
top = (
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
)
bottom = (
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
)
array: np.ndarray = array
if array.ndim == 2:
return np.pad(
array,
(
(left, right),
(top, bottom),
),
constant_values=(np.nan, np.nan),
)
elif array.ndim == 3:
return np.pad(
array,
(
(left, right),
(top, bottom),
(0,0)
),
constant_values=(np.nan, np.nan),
)
def traverse_dict_update(target, source): def traverse_dict_update(target, source):
for k, v in source.items(): for k, v in source.items():
if isinstance(v, dict): if isinstance(v, dict):
try:
if k not in target: if k not in target:
target[k] = {} target[k] = {}
traverse_dict_update(target[k], v) traverse_dict_update(target[k], v)
except TypeError:
if k not in target.__dict__:
setattr(target, k, {})
traverse_dict_update(target.__dict__[k], v)
else: else:
try: try:
target[k] = v target[k] = v
@@ -261,6 +321,7 @@ class PolarizationTrainer:
target_delay=in_out_delay, target_delay=in_out_delay,
xy_delay=xy_delay, xy_delay=xy_delay,
drop_first=self.data_settings.drop_first, drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype, dtype=dtype,
real=not dtype.is_complex, real=not dtype.is_complex,
num_symbols=num_symbols, num_symbols=num_symbols,
@@ -602,6 +663,7 @@ class RegenerationTrainer:
console=None, console=None,
checkpoint_path=None, checkpoint_path=None,
settings_override=None, settings_override=None,
new_model=False,
reset_epoch=False, reset_epoch=False,
): ):
self.resume = checkpoint_path is not None self.resume = checkpoint_path is not None
@@ -615,12 +677,23 @@ class RegenerationTrainer:
models.regenerator, models.regenerator,
torch.nn.utils.parametrizations.orthogonal, torch.nn.utils.parametrizations.orthogonal,
]) ])
# self.new_model = True
self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
if self.resume: if self.resume:
print(f"loading checkpoint from {checkpoint_path}") print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None: if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override) traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
if not new_model:
# self.new_model = False
checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
if checkpoint_file.startswith("best"):
self.model_name = "_".join(checkpoint_file.split("_")[1:])
else:
self.model_name = "_".join(checkpoint_file.split("_")[:-1])
if new_model or reset_epoch:
self.checkpoint_dict["epoch"] = -1 self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"] self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
@@ -654,7 +727,7 @@ class RegenerationTrainer:
self.writer = None self.writer = None
def setup_tb_writer(self, append=None): def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S")) log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
if append is not None: if append is not None:
log_dir += "_" + str(append) log_dir += "_" + str(append)
@@ -697,7 +770,7 @@ class RegenerationTrainer:
output_dim = self.model_settings.output_dim output_dim = self.model_settings.output_dim
# if self.data_settings.polarisations: if self.data_settings.polarisations:
output_dim *= 2 output_dim *= 2
dtype = getattr(torch, self.data_settings.dtype) dtype = getattr(torch, self.data_settings.dtype)
@@ -755,11 +828,13 @@ class RegenerationTrainer:
randomise_polarisations = self.data_settings.randomise_polarisations randomise_polarisations = self.data_settings.randomise_polarisations
polarisations = self.data_settings.polarisations polarisations = self.data_settings.polarisations
osnr = self.data_settings.osnr osnr = self.data_settings.osnr
# cross_pol_interference = self.data_settings.cross_pol_interference
if override is not None: if override is not None:
num_symbols = override.get("num_symbols", None) num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path) config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations) polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations) randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# cross_pol_interference = override.get("angle_var", 0)
# get dataset # get dataset
dataset = FiberRegenerationDataset( dataset = FiberRegenerationDataset(
file_path=config_path, file_path=config_path,
@@ -768,11 +843,13 @@ class RegenerationTrainer:
target_delay=in_out_delay, target_delay=in_out_delay,
xy_delay=xy_delay, xy_delay=xy_delay,
drop_first=self.data_settings.drop_first, drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype, dtype=dtype,
real=not dtype.is_complex, real=not dtype.is_complex,
num_symbols=num_symbols, num_symbols=num_symbols,
randomise_polarisations=randomise_polarisations, randomise_polarisations=randomise_polarisations,
polarisations=polarisations, polarisations=polarisations,
# cross_pol_interference=cross_pol_interference,
osnr = osnr, osnr = osnr,
) )
@@ -842,8 +919,10 @@ class RegenerationTrainer:
running_loss = 0.0 running_loss = 0.0
self.model.train() self.model.train()
loader_len = len(train_loader) loader_len = len(train_loader)
x_key = "x_stacked"# if self.data_settings.polarisations else "x" # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y" # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
for batch_idx, batch in enumerate(train_loader): for batch_idx, batch in enumerate(train_loader):
x = batch[x_key] x = batch[x_key]
y = batch[y_key] y = batch[y_key]
@@ -855,7 +934,10 @@ class RegenerationTrainer:
angle.to(self.pytorch_settings.device), angle.to(self.pytorch_settings.device),
) )
y_pred = self.model(x, -angle) y_pred = self.model(x, -angle)
# loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True) loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item() loss_value = loss.item()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@@ -898,8 +980,10 @@ class RegenerationTrainer:
self.model.eval() self.model.eval()
running_error = 0 running_error = 0
x_key = "x_stacked"# if self.data_settings.polarisations else "x" # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y" # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
with torch.no_grad(): with torch.no_grad():
for _, batch in enumerate(valid_loader): for _, batch in enumerate(valid_loader):
x = batch[x_key] x = batch[x_key]
@@ -911,7 +995,9 @@ class RegenerationTrainer:
angle.to(self.pytorch_settings.device), angle.to(self.pytorch_settings.device),
) )
y_pred = self.model(x, -angle) y_pred = self.model(x, -angle)
# error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True) error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item() error_value = error.item()
running_error += error_value running_error += error_value
@@ -928,7 +1014,7 @@ class RegenerationTrainer:
if (epoch + 1) % 10 == 0 or epoch < 10: if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs # plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1) title_append, subtitle = self.build_title(epoch + 1)
head_fig, eye_fig, powers_fig = self.plot_model_response( head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model, model=self.model,
title_append=title_append, title_append=title_append,
subtitle=subtitle, subtitle=subtitle,
@@ -944,6 +1030,11 @@ class RegenerationTrainer:
eye_fig, eye_fig,
epoch + 1, epoch + 1,
) )
self.writer.add_figure(
"weights",
weight_fig,
epoch + 1,
)
self.writer.add_figure( self.writer.add_figure(
"powers", "powers",
@@ -967,9 +1058,10 @@ class RegenerationTrainer:
regen = [] regen = []
timestamps = [] timestamps = []
angles = [] angles = []
x_key = "x_stacked"# if self.data_settings.polarisations else "x" # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y" # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
with torch.no_grad(): with torch.no_grad():
model = model.to(self.pytorch_settings.device) model = model.to(self.pytorch_settings.device)
for batch in loader: for batch in loader:
@@ -1056,7 +1148,7 @@ class RegenerationTrainer:
) )
title_append, subtitle = self.build_title(0) title_append, subtitle = self.build_title(0)
head_fig, eye_fig, powers_fig = self.plot_model_response( head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model, model=self.model,
title_append=title_append, title_append=title_append,
subtitle=subtitle, subtitle=subtitle,
@@ -1072,6 +1164,11 @@ class RegenerationTrainer:
eye_fig, eye_fig,
0, 0,
) )
self.writer.add_figure(
"weights",
weight_fig,
0,
)
self.writer.add_figure( self.writer.add_figure(
"powers", "powers",
@@ -1103,6 +1200,9 @@ class RegenerationTrainer:
train_loader, valid_loader = self.get_sliced_data() train_loader, valid_loader = self.get_sliced_data()
# train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
# train_loader.dataset.fiber_in.to(self.pytorch_settings.device)
optimizer_name = self.optimizer_settings.optimizer optimizer_name = self.optimizer_settings.optimizer
# lr = self.optimizer_settings.learning_rate # lr = self.optimizer_settings.learning_rate
@@ -1132,6 +1232,7 @@ class RegenerationTrainer:
# except ValueError: # except ValueError:
# pass # pass
self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs): for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True enable_progress = True
if enable_progress: if enable_progress:
@@ -1147,9 +1248,48 @@ class RegenerationTrainer:
epoch, epoch,
enable_progress=enable_progress, enable_progress=enable_progress,
) )
if self.early_stop(loss):
self.save_model_checkpoints(epoch, loss)
break
if self.optimizer_settings.scheduler is not None: if self.optimizer_settings.scheduler is not None:
self.scheduler.step(loss) self.scheduler.step(loss)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch) self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
self.save_model_checkpoints(epoch, loss)
self.writer.flush()
save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
print(f"Training complete. Best checkpoint: {save_path}")
self.writer.close()
return self.best
def early_stop(self, loss):
# not stopping early at all
if not self.optimizer_settings.early_stopping:
return False
# stopping because of abs threshold
if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None:
if loss <= loss_thr:
print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
return True
# update vals
if loss < self.early_stop_vals["min_loss"]:
self.early_stop_vals["min_loss"] = loss
self.early_stop_vals["plateau_cnt"] = 0
return False
# stopping because of plateau
if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None:
self.early_stop_vals["plateau_cnt"] += 1
if self.early_stop_vals["plateau_cnt"] >= plateau_thresh:
print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals["plateau_cnt"]} >= {plateau_thresh})")
return True
# no stop
return False
def save_model_checkpoints(self, epoch, loss):
if self.pytorch_settings.save_models and self.model is not None: if self.pytorch_settings.save_models and self.model is not None:
save_path = ( save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar" Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
@@ -1165,10 +1305,6 @@ class RegenerationTrainer:
) )
save_path.parent.mkdir(parents=True, exist_ok=True) save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path) self.save_checkpoint(self.best, save_path)
self.writer.flush()
self.writer.close()
return self.best
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True): def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
powers = [power / powers[0] for power in powers] powers = [power / powers[0] for power in powers]
@@ -1190,6 +1326,77 @@ class RegenerationTrainer:
plt.show() plt.show()
return fig return fig
def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
model_params = []
plots = []
dims = []
for num, (layer_name, layer) in enumerate(model.named_children()):
onn_weights = layer.ONN.weight
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles)})
dims.append(onn_weights.shape[0])
max_size = np.max(dims)
for plot in plots:
layer_name, (num, onn_values, onn_angles) = plot.popitem()
if num == 0:
value_img = onn_values
angle_img = onn_angles
onn_angles = pad_to_size(onn_angles, (max_size, None))
onn_values = pad_to_size(onn_values, (max_size, None))
else:
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
value_img = np.concatenate((value_img, onn_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
fig.tight_layout()
dividers = map(make_axes_locatable, axs)
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
masked_value_img = value_img
cmap = cm.batlow
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
cmap = cm.romaO
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])
axs[0].axis("off")
axs[1].axis("off")
axs[0].set_title("Values")
axs[1].set_title("Angles")
title = "Layer Weights"
if title_append:
title += f" {title_append}"
if subtitle:
title += f"\n{subtitle}"
fig.suptitle(title)
if show:
plt.show()
return fig
def _plot_model_response_eye( def _plot_model_response_eye(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
): ):
@@ -1354,7 +1561,7 @@ class RegenerationTrainer:
data_settings_backup = copy.deepcopy(self.data_settings) data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 99.5 + random.randint(0, 1000) self.data_settings.drop_first = int(64 + random.randint(0, 1000))
self.data_settings.shuffle = False self.data_settings.shuffle = False
self.data_settings.train_split = 1.0 self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols) self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
@@ -1363,7 +1570,7 @@ class RegenerationTrainer:
if isinstance(self.data_settings.config_path, (list, tuple)) if isinstance(self.data_settings.config_path, (list, tuple))
else self.data_settings.config_path else self.data_settings.config_path
) )
fiber_length = int(float(str(config_path).split("-")[4]) / 1000) # fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
if not hasattr(self, "_plot_loader"): if not hasattr(self, "_plot_loader"):
self._plot_loader, _ = self.get_sliced_data( self._plot_loader, _ = self.get_sliced_data(
override={ override={
@@ -1376,6 +1583,7 @@ class RegenerationTrainer:
} }
) )
self._sps = self._plot_loader.dataset.samples_per_symbol self._sps = self._plot_loader.dataset.samples_per_symbol
fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
self.data_settings = data_settings_backup self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup self.pytorch_settings = pytorch_settings_backup
@@ -1403,7 +1611,7 @@ class RegenerationTrainer:
import gc import gc
head_fig = self._plot_model_response_head( head_fig = self._plot_model_response_head(
fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps], fiber_out[: self.pytorch_settings.head_symbols * self._sps],
fiber_in[: self.pytorch_settings.head_symbols * self._sps], fiber_in[: self.pytorch_settings.head_symbols * self._sps],
regen[: self.pytorch_settings.head_symbols * self._sps], regen[: self.pytorch_settings.head_symbols * self._sps],
angles[: self.pytorch_settings.head_symbols * self._sps], angles[: self.pytorch_settings.head_symbols * self._sps],
@@ -1417,7 +1625,7 @@ class RegenerationTrainer:
# raise NotImplementedError("Eye diagram not implemented") # raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye( eye_fig = self._plot_model_response_eye(
fiber_in[: self.pytorch_settings.eye_symbols * self._sps], fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps], fiber_out[: self.pytorch_settings.eye_symbols * self._sps],
regen[: self.pytorch_settings.eye_symbols * self._sps], regen[: self.pytorch_settings.eye_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps], timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
labels=("fiber in", "fiber out", "regen"), labels=("fiber in", "fiber out", "regen"),
@@ -1426,9 +1634,11 @@ class RegenerationTrainer:
subtitle=subtitle, subtitle=subtitle,
show=show, show=show,
) )
weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
gc.collect() gc.collect()
return head_fig, eye_fig, power_fig return head_fig, eye_fig, weight_fig, power_fig
def build_title(self, number: int): def build_title(self, number: int):
title_append = f"epoch {number}" title_append = f"epoch {number}"

View File

@@ -1,4 +1,6 @@
import os from pathlib import Path
import sys
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import numpy as np import numpy as np
import torch import torch
@@ -26,6 +28,28 @@ from hypertraining import models
# constant_values=(-np.inf, -np.inf), # constant_values=(-np.inf, -np.inf),
# ) # )
def register_puccs_cmap(puccs_path=None):
puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
colors = []
# keys = None
with open(puccs_path, "r") as f:
for i, line in enumerate(f.readlines()):
elements = tuple(line.split(","))
# if i == 0:
# # keys = elements
# continue
# else:
try:
colors.append(tuple(map(float, elements[4:])))
except ValueError:
continue
# colors = []
# for current in puccs_csv_data:
# colors.append(tuple(current[4:]))
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
def pad_to_size(array, size): def pad_to_size(array, size):
if not hasattr(size, "__len__"): if not hasattr(size, "__len__"):
@@ -65,7 +89,7 @@ def pad_to_size(array, size):
constant_values=(np.nan, np.nan), constant_values=(np.nan, np.nan),
) )
def model_plot(model_path): def model_plot(model_path, show=True):
torch.serialization.add_safe_globals([ torch.serialization.add_safe_globals([
*util.complexNN.__all__, *util.complexNN.__all__,
GlobalSettings, GlobalSettings,
@@ -81,173 +105,113 @@ def model_plot(model_path):
dims = checkpoint_dict["model_kwargs"].pop("dims") dims = checkpoint_dict["model_kwargs"].pop("dims")
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"]) model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
model.load_state_dict(checkpoint_dict["model_state_dict"]) model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)
model_params = [] model_params = []
plots = [] plots = []
max_size = np.max(dims) max_size = np.max(dims)
# max_act_size = np.max(dims[1:]) # max_act_size = np.max(dims[1:])
angles = [None, None] # angles = [None, None]
weights = [None, None] # weights = [None, None]
for num, (layer_name, layer) in enumerate(model.named_children()): for num, (layer_name, layer) in enumerate(model.named_children()):
# each layer contains an "ONN" layer and an "activation" layer # each layer contains an "ONN" layer and an "activation" layer
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees # activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
onn_weights = layer.ONN.weight.T onn_weights = layer.ONN.weight
onn_weights = onn_weights.detach().cpu().numpy() onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
act = layer.activation
act_values = np.ones((act.size, 1))
act_values = np.nan * act_values
act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy()
...
# act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8)
# act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8)
# xs = (0.01, 0.1, 1)
# act_values = np.zeros((act.size, len(xs)*2))
# act_angles = np.zeros((act.size, len(xs)*2))
# act_values[:,:] = np.nan
# act_angles[:,:] = np.nan
# for xi, x in enumerate(xs):
# phi_intermediate = act_phi_gain * x**2 + act_phi_bias
# act_resulting_gain = (
# 1j
# * torch.sqrt(1-act.alpha)
# * torch.exp(-0.5j * phi_intermediate)
# * torch.cos(0.5 * phi_intermediate)
# * x
# )
# act_resulting_gain = act_resulting_gain.detach().cpu().numpy()
# act_values[:, xi*2] = np.abs(act_resulting_gain).real
# act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real
# if angles[0] is None or angles[0] > np.min(onn_angles.flatten()):
# angles[0] = np.min(onn_angles.flatten())
# if angles[1] is None or angles[1] < np.max(onn_angles.flatten()):
# angles[1] = np.max(onn_angles.flatten())
# if weights[0] is None or weights[0] > np.min(onn_weights.flatten()):
# weights[0] = np.min(onn_weights.flatten())
# if weights[1] is None or weights[1] < np.max(onn_weights.flatten()):
# weights[1] = np.max(onn_weights.flatten())
model_params.append({layer_name: onn_weights}) model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)}) plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5)) # fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
for plot in plots: for plot in plots:
layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem() layer_name, (num, onn_values, onn_angles) = plot.popitem()
# for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0])
# for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0])
onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values))
onn_values = onn_values - np.min(onn_values)
onn_values = onn_values / np.max(onn_values)
act_values = np.ma.array(act_values, mask=np.isnan(act_values))
act_values = act_values - np.min(act_values)
act_values = act_values / np.max(act_values)
onn_values = onn_values
onn_values = pad_to_size(onn_values, (max_size, None))
act_values = act_values
act_values = pad_to_size(act_values, (max_size, 3))
onn_angles = onn_angles / np.pi
onn_angles = pad_to_size(onn_angles, (max_size, None))
act_angles = act_angles / np.pi
act_angles = pad_to_size(act_angles, (max_size, 3))
# onn_angles = onn_angles - np.min(onn_angles)
# onn_angles = onn_angles / np.max(onn_angles)
# act_angles = act_angles - np.min(act_angles)
# act_angles = act_angles / np.max(act_angles)
if num == 0: if num == 0:
value_img = np.concatenate((onn_values, act_values), axis=1) value_img = onn_values
angle_img = np.concatenate((onn_angles, act_angles), axis=1) angle_img = onn_angles
onn_angles = pad_to_size(onn_angles, (max_size, None))
onn_values = pad_to_size(onn_values, (max_size, None))
else: else:
value_img = np.concatenate((value_img, onn_values, act_values), axis=1) onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1) onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
value_img = np.concatenate((value_img, onn_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
# from cmcrameri import cm
from cmap import Colormap as cm
import scicomap as sc
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
# -np.inf to np.nan # fig.tight_layout()
# value_img[value_img == -np.inf] = np.nan dividers = map(make_axes_locatable, axs)
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
# angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size)
# angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size)
from cmcrameri import cm
from matplotlib import colors as mcolors
alpha_map = mcolors.LinearSegmentedColormap(
'alphamap',
{
'red': [(0, 0, 0), (1, 0, 0)],
'green': [(0, 0, 0), (1, 0, 0)],
'blue': [(0, 0, 0), (1, 0, 0)],
'alpha': [
(0, 1, 1),
# (0.2, 0.2, 0.1),
(1, 0, 0)
]
}
)
alpha_map.set_bad(color="#AAAAAA")
fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5))
fig.tight_layout()
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img) # masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
masked_value_img = value_img masked_value_img = value_img
cmap = cm.batlowW cmap = cm('google:turbo').to_matplotlib()
# cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
cmap.set_bad(color="#AAAAAA") cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap) im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img) masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
cmap = cm.romaO # cmap = cm('crameri:romao').to_matplotlib()
# cmap = plt.get_cmap('puccs')
# cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
cmap = cm('colorcet:CET_C8').to_matplotlib()
cmap.set_bad(color="#AAAAAA") cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap) im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap) cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map) cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])
# im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
# im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
axs[0].axis("off") axs[0].axis("off")
axs[1].axis("off") axs[1].axis("off")
axs[2].axis("off") # axs[2].axis("off")
axs[0].set_title("Values") axs[0].set_title("Values")
axs[1].set_title("Angles") axs[1].set_title("Angles")
axs[2].set_title("Values and Angles") # axs[2].set_title("Values and Angles")
... ...
if show:
plt.show() plt.show()
return fig
# model = models.regenerator(*dims, **model_kwargs) # model = models.regenerator(*dims, **model_kwargs)
if __name__ == "__main__": if __name__ == "__main__":
model_plot(".models/best_20250105_145719.tar") register_puccs_cmap()
if len(sys.argv) > 1:
model_plot(sys.argv[1])
else:
print("Please provide a model path as an argument")
# model_plot(".models/best_20250114_224234.tar")

View File

@@ -0,0 +1,102 @@
"x","L","a","b","R","G","B"
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
1 x L a b R G B
2 0. 0.5187848173343539 0.6399990176455989 0.67 0.8889427469969852 0.22673227640012172 0.
3 0.01 0.5374499525557803 0.604014067614707 0.6777967519386492 0.8956274406155226 0.27553288030331824 0.
4 0.02 0.5560867887452998 0.5680836759482211 0.6855816828789898 0.9019507507843885 0.318608215541461 0.
5 0.03 0.5746877595125583 0.5322224300667823 0.6933516322080414 0.907905487190649 0.3580633000693721 0.
6 0.04 0.5932314662487472 0.49647158484797804 0.7010976613543587 0.9134808162089558 0.3949845524063657 0.
7 0.05 0.6117000836392819 0.46086550613202343 0.7088123243737041 0.918668356138916 0.43002019316005363 0.
8 0.06 0.6300828534995973 0.4254249348741487 0.7164911273850869 0.923462736751354 0.4635961938811463 0.
9 0.07 0.6483763163456417 0.3901565406944371 0.7241326253017896 0.9278609626724071 0.49601354353255284 0.
10 0.08 0.6665840140182806 0.3550534951951814 0.7317382976124045 0.9318616057744784 0.5274983630587982 0.
11 0.09 0.6847162776119433 0.3200958808181962 0.7393124597949372 0.9354640163365924 0.5582303922647159 0.
12 0.1 0.7027902128942014 0.2852507189547545 0.7468622572263107 0.9386675557407496 0.5883604892249517 0.004034952213848706
13 0.11 0.7208298719332069 0.25047163906104203 0.7543977368741345 0.9414708123927996 0.6180221032545026 0.016031521294251994
14 0.12 0.7388665670611175 0.2156982733607376 0.7619319784446927 0.943870754968487 0.6473392272576862 0.029857267582036696
15 0.13 0.7569392765472108 0.18085547473834482 0.7694812638396673 0.9458617774020323 0.676432172396153 0.045365670193636125
16 0.14 0.7750950944867471 0.14585244938794778 0.7770652650825484 0.9474345911958609 0.7054219201084561 0.06017985923530026
17 0.15 0.793389684293558 0.11058188251425949 0.7847072337503834 0.9485749196617762 0.7344334940032564 0.07418869502646075
18 0.16 0.8117919447684838 0.07510373484536464 0.792394178330817 0.9492596163836376 0.7634480277996188 0.08767517868137237
19 0.17 0.8293050962981561 0.03629277424762101 0.799038155466063 0.9462308253550155 0.7922009241807345 0.10066327128139077
20 0.18 0.8213303100752708 -0.0062517290795987 0.7879999288492758 0.9088702681901394 0.7940579017644396 0.10139639009534024
21 0.19 0.8134831311534617 -0.048115463155645855 0.7771383286984362 0.8716809050191757 0.7954897210083888 0.10232311621802098
22 0.2 0.80558613530069 -0.0902449644291895 0.7662077749032042 0.8337524177888596 0.7965471523787845 0.10344968926026826
23 0.21 0.7975860185564765 -0.13292460297117392 0.7551344872795225 0.7947193410849823 0.7972381033243311 0.10477682283894393
24 0.22 0.7894147026971006 -0.17651756772919341 0.7438242359834689 0.7540941866826836 0.7975605026647324 0.10631182441371936
25 0.23 0.7809997374598548 -0.2214103719409295 0.7321767396537806 0.7112894518675287 0.7974995317311054 0.1080672415170634
26 0.24 0.7722646970273015 -0.2680107379394189 0.7200862142018722 0.6655745739336695 0.7970267795229349 0.11006041388465265
27 0.25 0.7631307298557146 -0.3167393290089981 0.7074435179925446 0.6160047476007512 0.7960993904970947 0.11231257117602686
28 0.26 0.7535192192483822 -0.36801555555407994 0.6941398344519211 0.5612859274945571 0.794659599537827 0.11484733363789801
29 0.27 0.7433557597838075 -0.42223636134393283 0.6800721760037781 0.4994862901720824 0.7926351396848288 0.11768844813479104
30 0.28 0.732575139048096 -0.479749646583324 0.6651502794883674 0.42731393423789277 0.7899410218414098 0.12085678487511567
31 0.29 0.7211269294461059 -0.5408244362880141 0.6493043460161184 0.3378265607222193 0.786483110019224 0.124366774034814
32 0.3 0.7090756028785993 -0.6051167807996883 0.6326236137723747 0.2098475715121697 0.7821998608677176 0.12819222127525928
33 0.31 0.7094510768540225 -0.6165036055456403 0.5630307498747129 0.15061488620640032 0.7845112116922692 0.21943537230975235
34 0.32 0.7174669421288304 -0.5917687864932311 0.4797229624661701 0.18766933782916642 0.7905828987725732 0.31091344246312086
35 0.33 0.7249009746435938 -0.5688293479200438 0.40246208306061504 0.21160609617940718 0.7962175427587832 0.38519766326885596
36 0.34 0.7317072855135611 -0.5478268906666535 0.3317250285377912 0.22717569971119178 0.8013847719431052 0.4490960048955565
37 0.35 0.7379328517830899 -0.5286164561226088 0.26702357292455026 0.23690087622812972 0.8061220291668977 0.5056371468159843
38 0.36 0.7436229063122554 -0.5110584677642499 0.20788761731555405 0.24226377668817778 0.8104638164122776 0.5563570758573497
39 0.37 0.7488251728809415 -0.4950056627547577 0.15382117501783654 0.24424372086048424 0.8144455902164638 0.6022301663745243
40 0.38 0.7535943992285348 -0.48028910419451787 0.10425526029155024 0.24352232677523483 0.818107753931944 0.6440238320299774
41 0.39 0.757994865186593 -0.4667104416936734 0.05852182167144754 0.240562414747303 0.8214980148949816 0.6824536572462205
42 0.4 0.7620994844391137 -0.4540446830999986 0.015863077249098356 0.2356325204239052 0.8246710357361025 0.7182393675419642
43 0.41 0.7659871096124125 -0.4420485102716773 -0.024540477496154123 0.22880568593963535 0.8276865975886148 0.7521146815529202
44 0.42 0.7697410958994951 -0.4304647113488041 -0.06355514164248566 0.21993360985514526 0.8306086550266585 0.7848331944479765
45 0.43 0.773446484628189 -0.4190308715098135 -0.10206473803580057 0.20858849290850018 0.833503273690861 0.8171544357676854
46 0.44 0.7771893686864673 -0.4074813310994203 -0.14096401824224686 0.1939295692427068 0.8364382500400466 0.8498448067259188
47 0.45 0.7810574093604746 -0.3955455908045306 -0.18116403397486242 0.17438366103820427 0.839483669055626 0.8836865023336339
48 0.46 0.7851360804917298 -0.3829599011818591 -0.2235531031349741 0.14679145002531463 0.8427091517444469 0.9194481212717681
49 0.47 0.789525027020907 -0.369416784561489 -0.26916682191206776 0.10278921007810798 0.8461971304126237 0.9580316568065935
50 0.48 0.7942371698732826 -0.35487637041943493 -0.3181394757087982 0.0013920913109500188 0.8499626968466341 0.9995866371771526
51 0.49 0.7773897680996302 -0.31852357140025195 -0.34537976514700053 0.10740420703601522 0.8254781216972907 1.
52 0.5 0.7604011244310231 -0.28211213216592784 -0.3722846952738428 0.1581725581872408 0.8008522647497104 1.
53 0.51 0.7433440454962605 -0.2455540169176899 -0.3992980063927199 0.19300141807932156 0.7761561224913385 1.
54 0.52 0.7262590833969331 -0.20893614020926626 -0.42635547610418184 0.2194621842292243 0.751443124097109 1.
55 0.53 0.709058602701224 -0.17207067467417486 -0.453595892719742 0.2405673704012788 0.7265803324554873 1.
56 0.54 0.6915768892539101 -0.1346024482921609 -0.48128169789479536 0.25788347992973676 0.701321051230534 1.
57 0.55 0.6736331627810209 -0.09614399811510127 -0.5096991935104321 0.2722888922216317 0.6753950894563805 1.
58 0.56 0.6551463184003872 -0.05652149358027936 -0.5389768254408652 0.28422807900785235 0.6486730893521468 1.
59 0.57 0.6361671326276888 -0.01584376303510615 -0.5690341788729347 0.293907374075009 0.6212117649042732 1.
60 0.58 0.6168396823565967 0.025580396234342995 -0.5996430791016598 0.301442767979156 0.5931976878638505 1.
61 0.59 0.5973210287815495 0.06741435793529688 -0.6305547881733555 0.30694603901024253 0.5648312189065924 1.
62 0.6 0.5777303704171711 0.10940264614179468 -0.661580531294122 0.3105418468883679 0.5362525958007331 1.
63 0.61 0.5581475370499237 0.15137416317967575 -0.6925938819599547 0.3123531986526998 0.5075386530652202 1.
64 0.62 0.5386227795100639 0.19322120739317136 -0.7235152578861672 0.31248922600720636 0.4787151440558522 1.
65 0.63 0.5191666876024412 0.23492108185347996 -0.754327887989376 0.31103663081260624 0.44973844514160927 1.
66 0.64 0.4996990584326256 0.2766456839100268 -0.7851587896650079 0.30803814950244496 0.4204116611935119 1.
67 0.65 0.479957679121191 0.3189570094767831 -0.8164232296840259 0.30343473603466015 0.390226489453496 1.
68 0.66 0.4600072725872886 0.3617163391430824 -0.8480187063016573 0.29717122075330515 0.3591178757512998 1.
69 0.67 0.44600100870220305 0.4113853615984094 -0.8697728377551008 0.3178994129506999 0.3295740682997879 1.
70 0.68 0.4574651571354146 0.44026390446569547 -0.8504539292487465 0.3842479358768364 0.3280946443367561 1.
71 0.69 0.4691809168948424 0.46977626401045774 -0.830711015748157 0.44293649140770447 0.3260767554252525 1.
72 0.7 0.4811696900083858 0.49997635259991063 -0.8105080314416201 0.49708450874457527 0.3234487047238236 1.
73 0.71 0.49350094811609174 0.5310391714342613 -0.7897279055963483 0.5485591109413528 0.3201099534066949 1.
74 0.72 0.5062548753068121 0.5631667067020758 -0.7682355153041539 0.5985798481027601 0.3159263917472715 1.
75 0.73 0.5195243020949684 0.5965928013272943 -0.7458744264238399 0.6480500606439057 0.31071717884730565 1.
76 0.74 0.5334043922713477 0.6315571758288618 -0.7224842728734379 0.6976685401842261 0.3042411890803418 1.
77 0.75 0.5479805812358602 0.6682750446095802 -0.697921082452685 0.7479712773579563 0.29618040787504757 1.
78 0.76 0.5633244502526606 0.7069267230777347 -0.6720642293775535 0.7993701361353484 0.28611136999256687 1.
79 0.77 0.5794956601139 0.7476624986056212 -0.6448131757501174 0.8521918014427678 0.2734527325942473 1.
80 0.78 0.5965429098573916 0.7906050455688622 -0.6160858559672187 0.9067003897516911 0.2573693489198746 1.
81 0.79 0.6145761476424179 0.8360313267658297 -0.5856969899409387 0.963334644317004 0.23648492980159264 1.
82 0.8 0.6232910688128902 0.859291371252556 -0.5300995185388214 1. 0.21867949406239662 0.9712088595948508
83 0.81 0.6159984336377875 0.8439887543380684 -0.44635440435952856 1. 0.21606849746358275 0.9041480210597966
84 0.82 0.6091642745073532 0.8296481879180277 -0.36787420852419694 1. 0.21421830096504035 0.8419706002336461
85 0.83 0.6025478038652375 0.8157644115969636 -0.2918938425681935 1. 0.21295365915197917 0.7823908751330636
86 0.84 0.5961857222953111 0.8024144366282877 -0.21883475834162458 0.9971140114799418 0.21220068235083267 0.7256713129328118
87 0.85 0.5900921771070883 0.7896279492437488 -0.1488594167412921 0.993273906363258 0.2118788857127918 0.671860243327784
88 0.86 0.5842771639541229 0.7774259239818333 -0.08208260304413262 0.9887084084529413 0.21191070453347688 0.6209624706933893
89 0.87 0.578741582584259 0.7658102488427286 -0.018514649521559012 0.9835846378805114 0.2122246941077346 0.5728987835613306
90 0.88 0.5734741590353537 0.7547572669288056 0.04197390858426542 0.9780378159372328 0.21275878699579343 0.5274829957183049
91 0.89 0.5684517008574971 0.7442183119942206 0.09964940221121898 0.9721670725313721 0.21346242315895625 0.4844270603851604
92 0.9 0.5636419856510335 0.7341257696545772 0.15488185789614228 0.9660363209686843 0.21429691147008262 0.4433660148378527
93 0.91 0.5590069340453534 0.7243997354573974 0.20810856081277884 0.9596781387247791 0.2152344151262528 0.4038812338146013
94 0.92 0.5545051525321143 0.7149533506766244 0.25980485409830323 0.9530986696850675 0.21625626438013962 0.3655130449917989
95 0.93 0.5500961975299247 0.705701749880514 0.3104351723857584 0.9462863346513658 0.21735046958786286 0.327780364198278
96 0.94 0.545740378056064 0.6965616468647046 0.36045530782708896 0.93921469089265 0.21851014470332586 0.29014917175372823
97 0.95 0.5414004092067859 0.6874548042588865 0.41029342232076466 0.9318478255642132 0.21973168075163751 0.2519897371806688
98 0.96 0.5370416605957644 0.6783085548415655 0.46034719456417006 0.9241434776436454 0.22101341980094052 0.2124579038400577
99 0.97 0.5326309593934517 0.6690532898786764 0.5109975653738162 0.9160532016485884 0.22235495330179011 0.17018252385769012
100 0.98 0.5281374148557197 0.6596241892863608 0.5625992691950712 0.90752576202319 0.22375597459867458 0.1223073280126531
101 0.99 0.5235317096396147 0.6499597345521199 0.615488972291106 0.8985077346125597 0.22521565729028564 0.05933950582860665
102 1. 0.5187848173343539 0.6399990176455989 0.67 0.8889427469969852 0.22673227640012172 0.

View File

@@ -26,28 +26,39 @@ global_settings = GlobalSettings(
) )
data_settings = DataSettings( data_settings = DataSettings(
# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini", # config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini", # config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], # config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
dtype="complex64", dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=4, # study: single_core_regen_20241123_011232 symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y)) # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2) output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
shuffle=True, shuffle=False,
drop_first=64, drop_first=256,
drop_last=256,
train_split=0.8, train_split=0.8,
randomise_polarisations=False, randomise_polarisations=False,
polarisations=True, polarisations=False,
# cross_pol_interference=0.01,
osnr=16, #16dB due to amplification with NF 5 osnr=16, #16dB due to amplification with NF 5
) )
pytorch_settings = PytorchSettings( pytorch_settings = PytorchSettings(
epochs=1000, epochs=1000,
batchsize=2**14, batchsize=2**13,
device="cuda", device="cuda",
dataloader_workers=24, dataloader_workers=32,
dataloader_prefetch=8, dataloader_prefetch=4,
summary_dir=".runs", summary_dir=".runs",
write_every=2**5, write_every=2**5,
save_models=True, save_models=True,
@@ -65,16 +76,13 @@ model_settings = ModelSettings(
# "n_hidden_nodes_3": 4, # "n_hidden_nodes_3": 4,
# "n_hidden_nodes_4": 2, # "n_hidden_nodes_4": 2,
}, },
model_activation_func="phase_shift", model_activation_func="EOActivation",
dropout_prob=0, dropout_prob=0,
model_layer_function="ONNRect", model_layer_function="ONNRect",
model_layer_kwargs={"square": True}, model_layer_kwargs={"square": True},
scale=2.0, scale=2.0,
model_layer_parametrizations=[ model_layer_parametrizations=[
{ # EOactivation
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{ {
"tensor_name": "alpha", "tensor_name": "alpha",
"parametrization": util.complexNN.clamp, "parametrization": util.complexNN.clamp,
@@ -83,54 +91,20 @@ model_settings = ModelSettings(
"max": 1, "max": 1,
}, },
}, },
# ONNRect
{ {
"tensor_name": "gain", "tensor_name": "weight",
"parametrization": util.complexNN.clamp, "parametrization": torch.nn.utils.parametrizations.orthogonal,
"kwargs": {
"min": 0,
"max": None,
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
}, },
# Scale
{ {
"tensor_name": "scale", "tensor_name": "scale",
"parametrization": util.complexNN.clamp, "parametrization": util.complexNN.clamp,
"kwargs": { "kwargs": {
"min": 0, "min": 0,
"max": 2, "max": 10,
},
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "bias",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
}, },
}
], ],
) )
@@ -145,107 +119,19 @@ optimizer_settings = OptimizerSettings(
scheduler="ReduceLROnPlateau", scheduler="ReduceLROnPlateau",
scheduler_kwargs={ scheduler_kwargs={
"patience": 2**6, "patience": 2**6,
"factor": 0.75, "factor": 0.5,
# "threshold": 1e-3, # "threshold": 1e-3,
"min_lr": 1e-6, "min_lr": 1e-6,
"cooldown": 10, "cooldown": 10,
}, },
early_stopping=True,
early_stop_kwargs={
"threshold": 1e-06,
"plateau": 2**7,
}
) )
def save_dict_to_file(dictionary, filename):
"""
Save the best dictionary to a JSON file.
:param best: Dictionary containing the best training results.
:type best: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
model = model
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
trainer = RegenerationTrainer(
checkpoint_path=model,
)
trainer.define_model()
for length in lengths:
data_glob_length = data_glob.replace("{length}", str(length))
files = list(Path.cwd().glob(data_glob_length))
if len(files) == 0:
continue
if strategy == "newest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
file = sorted(files, **sorted_kwargs)[0]
loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
data[2 + 2 * li, 0, :] = timestampss[length] / 128
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
channel_names[2 + 2 * li + 1] = f"regen x {length}"
channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot(all_stats=False)
matplotlib.use(backend)
if __name__ == "__main__": if __name__ == "__main__":
# lengths = range(90000, 100000+10000, 10000)
# lengths = [100000]
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
trainer = RegenerationTrainer( trainer = RegenerationTrainer(
global_settings=global_settings, global_settings=global_settings,
@@ -253,83 +139,15 @@ if __name__ == "__main__":
pytorch_settings=pytorch_settings, pytorch_settings=pytorch_settings,
model_settings=model_settings, model_settings=model_settings,
optimizer_settings=optimizer_settings, optimizer_settings=optimizer_settings,
# checkpoint_path=".models/best_20250104_191428.tar", checkpoint_path=".models/best_20250117_144001.tar",
reset_epoch=True, new_model=True,
# settings_override={ settings_override={
# "data_settings": { "data_settings": data_settings.__dict__,
# "config_path": "data/20241229-163*-128-16384-100000-*.ini",
# "polarisations": True,
# },
# "model_settings": {
# "scale": 2.0,
# }
# }
# "optimizer_settings": { # "optimizer_settings": {
# "optimizer_kwargs": { # "early_stop_kwargs":{
# "lr": 0.01, # "plateau": 2**8,
# },
# } # }
# } # }
# 20241202_143149 }
) )
trainer.train() trainer.train()
# from hypertraining.lighning_models import regenerator, regeneratorData
# import lightning as L
# model = regenerator(
# 2 * data_settings.output_size,
# *model_settings.overrides["hidden_layer_dims"],
# model_settings.output_dim,
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
# layer_func_kwargs=model_settings.model_layer_kwargs,
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
# act_func_kwargs=None,
# parametrizations=model_settings.model_layer_parametrizations,
# dtype=getattr(torch, data_settings.dtype),
# dropout_prob=model_settings.dropout_prob,
# scale_layers=model_settings.scale,
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
# )
# dm = regeneratorData(
# config_globs=data_settings.config_path,
# output_symbols=data_settings.symbols,
# output_dim=data_settings.output_size,
# dtype=getattr(torch, data_settings.dtype),
# drop_first=data_settings.drop_first,
# shuffle=data_settings.shuffle,
# train_split=data_settings.train_split,
# batch_size=pytorch_settings.batchsize,
# loader_settings={
# "num_workers": pytorch_settings.dataloader_workers,
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
# "pin_memory": True,
# "drop_last": True,
# },
# seed=global_settings.seed,
# )
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
# # from torch.utils.tensorboard import SummaryWriter
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
# trainer = L.Trainer(
# fast_dev_run=False,
# # max_epochs=pytorch_settings.epochs,
# max_epochs=2,
# enable_checkpointing=True,
# default_root_dir=f".models/{subdir}/",
# logger=logger,
# )
# trainer.fit(model, dm)

View File

@@ -12,6 +12,7 @@ Full license text in LICENSE file
""" """
import configparser import configparser
# import copy
from datetime import datetime from datetime import datetime
import hashlib import hashlib
from pathlib import Path from pathlib import Path
@@ -40,7 +41,7 @@ alpha = 0.2
D = 17 D = 17
S = 0.058 S = 0.058
bireflength = 10 bireflength = 10
max_delta_beta = 0.14 pmd_q = 0.2
; birefseed = 0xC0FFEE ; birefseed = 0xC0FFEE
[signal] [signal]
@@ -195,10 +196,14 @@ class pam_generator:
def initialize_fiber_and_data(config): def initialize_fiber_and_data(config):
f0 = config["glova"].get("f0", None)
if f0 is None:
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
config["glova"]["f0"] = f0
py_glova = pypho.setup( py_glova = pypho.setup(
nos=config["glova"]["nos"], nos=config["glova"]["nos"],
sps=config["glova"]["sps"], sps=config["glova"]["sps"],
f0=config["glova"]["f0"], f0=f0,
symbolrate=config["glova"]["symbolrate"], symbolrate=config["glova"]["symbolrate"],
wisdom_dir=config["glova"]["wisdom_dir"], wisdom_dir=config["glova"]["wisdom_dir"],
flags=config["glova"]["flags"], flags=config["glova"]["flags"],
@@ -216,7 +221,9 @@ def initialize_fiber_and_data(config):
symbolsrc = pypho.symbols( symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"] py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
) )
laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4) laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
modulator = pam_generator( modulator = pam_generator(
py_glova, py_glova,
mod_depth=config["signal"]["mod_depth"], mod_depth=config["signal"]["mod_depth"],
@@ -232,7 +239,12 @@ def initialize_fiber_and_data(config):
symbols_y[:3] = 0 symbols_y[:3] = 0
# symbols_x += 1 # symbols_x += 1
cw = laser()
cw = laserx()
# cwy = lasery()
# cw[0]['E'][0] = cw[0]['E'][0]
# cw[0]['E'][1] = cwy[0]['E'][0]
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y)) source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
@@ -251,13 +263,41 @@ def initialize_fiber_and_data(config):
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))] # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
## side channels
# df = 100
# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
# symbols_x_side = symbolsrc(pattern="random")
# symbols_y_side = symbolsrc(pattern="random")
# symbols_x_side[:3] = 0
# symbols_y_side[:3] = 0
# cw_left = laser(Df=-df)
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
# cw_right = laser(Df=df)
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
E_in_pure = source_signal[0]["E"]
nf = py_edfa.NF nf = py_edfa.NF
source_signal = py_edfa(E=source_signal, NF=0) pmean = py_edfa.Pmean
py_edfa.NF = nf
# ideal amplification to launch power into fiber
source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
c_data.E_in = source_signal[0]["E"] c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"] noise = source_signal[0]["noise"]
py_edfa.NF = nf
py_edfa.Pmean = pmean
py_fiber = pypho.fiber( py_fiber = pypho.fiber(
glova=py_glova, glova=py_glova,
l=config["fiber"]["length"], l=config["fiber"]["length"],
@@ -265,20 +305,29 @@ def initialize_fiber_and_data(config):
gamma=config["fiber"]["gamma"], gamma=config["fiber"]["gamma"],
D=config["fiber"]["d"], D=config["fiber"]["d"],
S=config["fiber"]["s"], S=config["fiber"]["s"],
phi_max=0.02,
) )
if config["fiber"].get("birefsteps", 0) > 0:
config["fiber"]["birefsteps"] = config["fiber"].get(
"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
)
if config["fiber"]["birefsteps"] > 0:
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32) seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre( py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l, config["fiber"]["length"],
py_fiber.l / config["fiber"]["birefsteps"], config["fiber"]["bireflength"],
# maxDeltaD=config["fiber"]["d"]/5, maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
seed=seed, seed=seed,
) )
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200) elif (dgd := config['fiber'].get('dgd', 0)) > 0:
py_fiber.birefarray = [
pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
]
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova) c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y) return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
def save_data(data, config, **metadata): def save_data(data, config, **metadata):
@@ -316,8 +365,11 @@ def save_data(data, config, **metadata):
f"D = {config['fiber']['d']}", f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}", f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps', 0)}", f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}", f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set", f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
f"dgd = {config['fiber'].get('dgd', 0)}",
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
f"pol_error = {config['fiber'].get('pol_error', 0)}",
"", "",
"[signal]", "[signal]",
f"seed = {hex(seed)}" if seed else "; seed = not set", f"seed = {hex(seed)}" if seed else "; seed = not set",
@@ -346,24 +398,12 @@ def save_data(data, config, **metadata):
save_file = f"{config_hash}.h5" save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n' config_content += f'"{str(save_file)}"\n'
filename_components = ( config_filename:Path = create_config_filename(config, data_dir, timestamp)
timestamp.strftime("%Y%m%d-%H%M%S"), while config_filename.exists():
config["glova"]["sps"], time.sleep(1)
config["glova"]["nos"], config_filename = create_config_filename(config, data_dir=data_dir)
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config["fiber"].get("birefsteps", 0),
config["fiber"].get("max_delta_beta", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
config_filename = data_dir / lookup_file
with open(config_filename, "w") as f: with open(config_filename, "w") as f:
f.write(config_content) f.write(config_content)
@@ -376,11 +416,31 @@ def save_data(data, config, **metadata):
outfile.attrs[key] = value outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data) # np.save(save_dir / save_file, save_data)
print("Saved config to", config_filename) # print("Saved config to", config_filename)
print("Saved data to", save_dir / save_file) # print("Saved data to", save_dir / save_file)
return config_filename return config_filename
def create_config_filename(config, data_dir:Path, timestamp=None):
if timestamp is None:
timestamp = datetime.now()
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config["fiber"].get("birefsteps", 0),
config["fiber"].get("pmd_q", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
return data_dir / lookup_file
def length_loop(config, lengths, save=True): def length_loop(config, lengths, save=True):
lengths = sorted(lengths) lengths = sorted(lengths)
@@ -388,7 +448,7 @@ def length_loop(config, lengths, save=True):
print(f"\nGenerating data for fiber length {length}m") print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length config["fiber"]["length"] = length
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config) cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
cfiber() cfiber()
@@ -416,51 +476,49 @@ def single_run_with_plot(config, save=True):
in_out_eyes(cfiber, cdata, show_pols=False) in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename return config_filename
def single_run(config, save=True):
cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)
# mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) def single_run(config, save=True, silent=True):
# print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)") cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_in / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
# transmit
cfiber() cfiber()
# mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out)) # amplify
# print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
# noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_out / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
E_tmp = [{"E": cdata.E_out, "noise": noise}] E_tmp = [{"E": cdata.E_out, "noise": noise}]
E_tmp = edfa(E=E_tmp) E_tmp = edfa(E=E_tmp)
# rotate
# ortho error
ortho_error = config["fiber"].get("ortho_error", 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
), axis=0)
pol_error = config['fiber'].get('pol_error', 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
), axis=0)
# output
cdata.E_out = E_tmp[0]["E"] cdata.E_out = E_tmp[0]["E"]
# noise = E_tmp[0]["noise"]
# mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_amp / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
config_filename = None config_filename = None
symbols = np.array(symbols) symbols = np.array(symbols)
if save: if save:
config_filename = save_data(cdata, config, **{"symbols": symbols}) config_filename = save_data(cdata, config, **{"symbols": symbols})
return cfiber,cdata,config_filename if not silent:
print(f"Saved config to {config_filename}")
return cfiber, cdata, config_filename
def in_out_eyes(cfiber, cdata, show_pols=False): def in_out_eyes(cfiber, cdata, show_pols=False):

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -441,8 +441,7 @@ class input_rotator(nn.Module):
# return out # return out
#### as defined by zhang et al #### as defined by zhang et alas
class DropoutComplex(nn.Module): class DropoutComplex(nn.Module):
def __init__(self, p=0.5): def __init__(self, p=0.5):
@@ -464,7 +463,7 @@ class Scale(nn.Module):
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32)) self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x): def forward(self, x):
return x * self.scale return x * torch.sqrt(self.scale)
def __repr__(self): def __repr__(self):
return f"Scale({self.size})" return f"Scale({self.size})"
@@ -546,35 +545,31 @@ class EOActivation(nn.Module):
raise ValueError("Size must be specified") raise ValueError("Size must be specified")
self.size = size self.size = size
self.alpha = nn.Parameter(torch.rand(size)) self.alpha = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.rand(size)) self.gain = nn.Parameter(torch.rand(size))
# if bias: self.V_bias = nn.Parameter(torch.rand(size))
# self.phase_bias = nn.Parameter(torch.zeros(size)) # self.register_buffer("gain", torch.ones(size))
# else: # self.register_buffer("responsivity", torch.ones(size))
# self.register_buffer("phase_bias", torch.zeros(size)) # self.register_buffer("V_pi", torch.ones(size))
# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.reset_weights() self.reset_weights()
def reset_weights(self): def reset_weights(self):
if "alpha" in self._parameters: if "alpha" in self._parameters:
self.alpha.data = torch.rand(self.size) self.alpha.data = torch.rand(self.size)
if "V_pi" in self._parameters: # if "V_pi" in self._parameters:
self.V_pi.data = torch.rand(self.size)*3 # self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters: if "V_bias" in self._parameters:
self.V_bias.data = torch.randn(self.size) self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters: if "gain" in self._parameters:
self.gain.data = torch.rand(self.size) self.gain.data = torch.rand(self.size)
if "responsivity" in self._parameters: # if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9 # self.responsivity.data = torch.ones(self.size)*0.9
# if "bias" in self._parameters: # if "bias" in self._parameters:
# self.phase_bias.data = torch.zeros(self.size) # self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8) phi_b = torch.pi * self.V_bias# / (self.V_pi)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8) g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
intermediate = g_phi * x.abs().square() + phi_b intermediate = g_phi * x.abs().square() + phi_b
return ( return (
1j 1j

View File

@@ -0,0 +1,105 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
import hashlib
import h5py
import numpy as _np
from scipy.interpolate import interp1d as _interp1d
from scipy.ndimage import gaussian_filter as _gaussian_filter
from ._brescount import bres_curve_count as _bres_curve_count
from pathlib import Path
__all__ = ['grid_count']
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
"""
Parameters
----------
`y` is the 1-d array of signal samples.
`window_size` is the number of samples to show horizontally in the
eye diagram. Typically this is twice the number of samples in a
"symbol" (i.e. in a data bit).
`offset` is the number of initial samples to skip before computing
the eye diagram. This allows the overall phase of the diagram to
be adjusted.
`size` must be a tuple of two integers. It sets the size of the
array of counts, (height, width). The default is (800, 640).
`fuzz`: If True, the values in `y` are reinterpolated with a
random "fuzz factor" before plotting in the eye diagram. This
reduces an aliasing-like effect that arises with the use of
Bresenham's algorithm.
`bounds` must be a tuple of two floating point values, (ymin, ymax).
These set the y range of the returned array. If not given, the
bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
`y.max() - y.min()`.
Return Value
------------
Returns a numpy array of integers.
"""
# hash input params
param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
cache_dir = Path.home()/".eyediagram"/".cache"
cache_dir.mkdir(parents=True, exist_ok=True)
if (cache_dir/param_hash).is_file():
try:
with h5py.File(cache_dir/param_hash, "r") as infile:
counts = infile["counts"][:]
if counts.len() != 0:
return counts
except:
pass
if size is None:
size = (800, 640)
height, width = size
dt = width / window_size
counts = _np.zeros((width, height), dtype=_np.int32)
if bounds is None:
ymin = y.min()
ymax = y.max()
yamp = ymax - ymin
ymin = ymin - 0.05*yamp
ymax = ymax + 0.05*yamp
ymax = _np.ceil(ymax*10)/10
ymin = _np.floor(ymin*10)/10
else:
ymin, ymax = bounds
start = offset
while start + window_size < len(y):
end = start + window_size
yy = y[start:end+1]
k = _np.arange(len(yy))
xx = dt*k
if fuzz:
f = _interp1d(xx, yy, kind='cubic')
jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
xx[1:-1] += jiggle
yd = f(xx)
else:
yd = yy
iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
_bres_curve_count(xx.astype(_np.int32), iyd, counts)
start = end
if blur != 0:
counts = _gaussian_filter(counts, sigma=blur)
with h5py.File(cache_dir/param_hash, "w") as outfile:
outfile.create_dataset("data", data=counts)
return counts

View File

@@ -25,13 +25,14 @@ import multiprocessing as mp
# def __len__(self): # def __len__(self):
# return len(self.indices) # return len(self.indices)
def load_from_file(datapath): def load_from_file(datapath):
if str(datapath).endswith('.h5'): if str(datapath).endswith(".h5"):
symbols = None symbols = None
with h5py.File(datapath, "r") as infile: with h5py.File(datapath, "r") as infile:
data = infile["data"][:] data = infile["data"][:]
try: try:
symbols = infile["symbols"][:] symbols = np.swapaxes(infile["symbols"][:], 0, 1)
except KeyError: except KeyError:
pass pass
else: else:
@@ -40,7 +41,7 @@ def load_from_file(datapath):
return data, symbols return data, symbols
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None): def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
filepath = Path(config_path) filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name) filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -58,12 +59,20 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, d
data, orig_symbols = load_from_file(datapath) data, orig_symbols = load_from_file(datapath)
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)] data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
orig_symbols = orig_symbols[skipfirst:symbols+skipfirst] orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps)) timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
data *= np.sqrt(normalize) data *= np.sqrt(normalize)
launch_power = float(config["signal"]["laser_power"])
output_power = float(config["signal"]["edfa_power"])
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
data[:, 0:2] *= np.sqrt(target_normalization)
# if normalize: # if normalize:
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude # # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
# a, b, c, d = data.T # a, b, c, d = data.T
@@ -132,13 +141,15 @@ class FiberRegenerationDataset(Dataset):
target_delay: float | int = 0, target_delay: float | int = 0,
xy_delay: float | int = 0, xy_delay: float | int = 0,
drop_first: float | int = 0, drop_first: float | int = 0,
drop_last=0,
dtype: torch.dtype = None, dtype: torch.dtype = None,
real: bool = False, real: bool = False,
device=None, device=None,
# osnr: float|None = None, # osnr: float|None = None,
polarisations = None, polarisations=None,
randomise_polarisations: bool = False, randomise_polarisations: bool = False,
repeat_randoms: int = 1, repeat_randoms: int = 1,
# cross_pol_interference: float = 0,
**kwargs, **kwargs,
): ):
""" """
@@ -172,6 +183,7 @@ class FiberRegenerationDataset(Dataset):
assert drop_first >= 0, "drop_first must be non-negative" assert drop_first >= 0, "drop_first must be non-negative"
self.randomise_polarisations = randomise_polarisations self.randomise_polarisations = randomise_polarisations
# self.cross_pol_interference = cross_pol_interference
data_raw = None data_raw = None
self.config = None self.config = None
@@ -181,6 +193,7 @@ class FiberRegenerationDataset(Dataset):
data, config, orig_syms = load_data( data, config, orig_syms = load_data(
file_path, file_path,
skipfirst=drop_first, skipfirst=drop_first,
skiplast=drop_last,
symbols=kwargs.get("num_symbols", None), symbols=kwargs.get("num_symbols", None),
real=real, real=real,
normalize=1000, normalize=1000,
@@ -300,20 +313,18 @@ class FiberRegenerationDataset(Dataset):
# fiber_out: [E_out_x, E_out_y, timestamps] # fiber_out: [E_out_x, E_out_y, timestamps]
# add noise related to amplification necessary due to splitting of the signal # add noise related to amplification necessary due to splitting of the signal
gain_lin = output_dim*2 # gain_lin = output_dim*2
edfa_nf = float(self.config["signal"]["edfa_nf"]) # gain_lin = 1
nf_lin = 10**(edfa_nf/10) # edfa_nf = float(self.config["signal"]["edfa_nf"])
f0 = float(self.config["glova"]["f0"]) # nf_lin = 10**(edfa_nf/10)
# f0 = float(self.config["glova"]["f0"])
noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
noise = torch.randn_like(fiber_out[:2, :])
noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
noise = noise * torch.sqrt(noise_add / noise_power)
fiber_out[:2, :] += noise
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
# noise = torch.randn_like(fiber_out[:2, :])
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
# noise = noise * torch.sqrt(noise_add / noise_power)
# fiber_out[:2, :] += noise
# if osnr is None: # if osnr is None:
# noisy = fiber_out[:2, :] # noisy = fiber_out[:2, :]
@@ -324,7 +335,6 @@ class FiberRegenerationDataset(Dataset):
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy] # fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
if repeat_randoms > 1: if repeat_randoms > 1:
fiber_in = fiber_in.repeat(1, 1, repeat_randoms) fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
fiber_out = fiber_out.repeat(1, 1, repeat_randoms) fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
@@ -334,8 +344,9 @@ class FiberRegenerationDataset(Dataset):
if self.randomise_polarisations: if self.randomise_polarisations:
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
# start_angle = torch.rand(1) * 2 * torch.pi start_angle = torch.rand(1) * 2 * torch.pi
# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
else: else:
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device) angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
@@ -361,8 +372,6 @@ class FiberRegenerationDataset(Dataset):
# 4 E_out_y_rot, # 4 E_out_y_rot,
# 5 angle # 5 angle
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout # data layout
# [ [E_in_x, E_in_y, timestamps], # [ [E_in_x, E_in_y, timestamps],
@@ -374,6 +383,9 @@ class FiberRegenerationDataset(Dataset):
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1) self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_out = self.fiber_out.movedim(-2, 0) self.fiber_out = self.fiber_out.movedim(-2, 0)
# if self.randomise_polarisations:
# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) # self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
# self.data = self.data.movedim(-2, 0) # self.data = self.data.movedim(-2, 0)
# self.angles = torch.zeros(self.data.shape[0]) # self.angles = torch.zeros(self.data.shape[0])
@@ -392,12 +404,12 @@ class FiberRegenerationDataset(Dataset):
return self.fiber_in.shape[0] return self.fiber_in.shape[0]
def add_noise(self, data, osnr): def add_noise(self, data, osnr):
osnr_lin = 10**(osnr/10) osnr_lin = 10 ** (osnr / 10)
popt = torch.mean(data.abs().square().squeeze(), dim=-1) popt = torch.mean(data.abs().square().squeeze(), dim=-1)
noise = torch.randn_like(data) noise = torch.randn_like(data)
pn = torch.mean(noise.abs().square().squeeze(), dim=-1) pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
mult = torch.sqrt(popt/(pn*osnr_lin)) mult = torch.sqrt(popt / (pn * osnr_lin))
mult = mult * torch.eye(popt.shape[0], device=mult.device) mult = mult * torch.eye(popt.shape[0], device=mult.device)
mult = mult.to(dtype=noise.dtype) mult = mult.to(dtype=noise.dtype)
@@ -406,7 +418,6 @@ class FiberRegenerationDataset(Dataset):
noisy = data + noise noisy = data + noise
return noisy return noisy
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, slice): if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))] return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
@@ -418,6 +429,10 @@ class FiberRegenerationDataset(Dataset):
output_dim = self.output_dim // 2 output_dim = self.output_dim // 2
self.output_dim = output_dim * 2 self.output_dim = output_dim * 2
if not self.polarisations:
output_dim = 2 * output_dim
fiber_in = self.fiber_in[idx].squeeze() fiber_in = self.fiber_in[idx].squeeze()
fiber_out = self.fiber_out[idx].squeeze() fiber_out = self.fiber_out[idx].squeeze()
@@ -427,85 +442,35 @@ class FiberRegenerationDataset(Dataset):
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1) fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1) fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
center_angle = fiber_out[5, output_dim // 2, 0]
# data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
# data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
# angle = self.angles[idx]
# fiber_in:
# 0 E_in_x,
# 1 E_in_y,
# 2 timestamps
# fiber_out:
# 0 E_out_x,
# 1 E_out_y,
# 2 timestamps,
# 3 E_out_x_rot,
# 4 E_out_y_rot,
# 5 angle
center_angle = fiber_out[0, output_dim // 2, 0]
angles = fiber_out[5, :, 0] angles = fiber_out[5, :, 0]
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone() plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
data = fiber_out[0:2, :, 0] data = fiber_out[0:2, :, 0]
# fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone() plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
# if self.polarisations:
# rot = int(np.random.randint(2)*2-1)
# pol_flipped_data[0:1, :] = rot*data[0, :]
# pol_flipped_data[1, :] = rot*data[1, :]
# plot_data_rot[0] = rot*plot_data_rot[0]
# plot_data_rot[1] = rot*plot_data_rot[1]
# center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
# angles = angles + (rot - 1) * torch.pi/2
# if self.randomise_polarisations:
# data = data.mT
# c = torch.cos(angle).unsqueeze(-1)
# s = torch.sin(angle).unsqueeze(-1)
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
# data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
...
# angle = torch.zeros_like(angle)
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
# angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim)
# angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim)
# sop = self.polarimeter(plot_data)
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
target = fiber_in[:2, output_dim // 2, 0] target = fiber_in[:2, output_dim // 2, 0]
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone() plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
target_timestamp = fiber_in[2, output_dim // 2, 0].real target_timestamp = fiber_in[2, output_dim // 2, 0].real
... ...
if self.polarisations: if self.polarisations:
rot = int(np.random.randint(2)*2-1) rot = int(np.random.randint(2) * 2 - 1)
data = rot*data data = rot * data
target = rot*target target = rot * target
plot_data_rot = rot*plot_data_rot plot_data_rot = rot * plot_data_rot
center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0 center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
angles = angles + (rot - 1) * torch.pi/2 angles = angles + (rot - 1) * torch.pi / 2
pol_flipped_data = -data pol_flipped_data = -data
pol_flipped_target = -target pol_flipped_target = -target
# data_timestamps = data[-1,:].real
# data = data[:-1, :]
# target_timestamp = target[-1].real
# target = target[:-1]
# plot_data = plot_data[:-1]
# transpose to interleave the x and y data in the output tensor # transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze() data = data.transpose(0, 1).flatten().squeeze()
data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze() pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
pol_flipped_data = pol_flipped_data / torch.sqrt(
torch.ones(1) * len(pol_flipped_data)
) # power loss due to splitting
# angle_data = angle_data.transpose(0, 1).flatten().squeeze() # angle_data = angle_data.transpose(0, 1).flatten().squeeze()
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze() # angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
center_angle = center_angle.flatten().squeeze() center_angle = center_angle.flatten().squeeze()
@@ -526,8 +491,8 @@ class FiberRegenerationDataset(Dataset):
"y": target, "y": target,
"y_flipped": pol_flipped_target, "y_flipped": pol_flipped_target,
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1), "y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
# "center_angle": center_angle, "center_angle": center_angle,
# "angles": angles, "angles": angles,
"mean_angle": angles.mean(), "mean_angle": angles.mean(),
# "sop": sop, # "sop": sop,
# "angle_data": angle_data, # "angle_data": angle_data,

View File

@@ -1,16 +1,23 @@
from datetime import datetime
import json
from pathlib import Path
from typing import Literal
import h5py
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
# from cmap import Colormap as cm
import numpy as np import numpy as np
from scipy.cluster.vq import kmeans2 from scipy.cluster.vq import kmeans2
import warnings import warnings
import multiprocessing import multiprocessing
from rich.traceback import install from rich.traceback import install
from rich import pretty
from rich import print
install() install()
pretty.install() # from rich import pretty
# from rich import print
# pretty.install()
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1): def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
xaxis = np.arange(0, len(signal)) / sps xaxis = np.arange(0, len(signal)) / sps
return np.vstack([xaxis, signal]) return np.vstack([xaxis, signal])
def create_symbol_sequence(n_symbols, skew=1): def create_symbol_sequence(n_symbols, skew=1):
np.random.seed(42) np.random.seed(42)
data = np.random.randint(0, 4, n_symbols) / 4 data = np.random.randint(0, 4, n_symbols) / 4
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
signal = np.convolve(data_padded, wavelet) signal = np.convolve(data_padded, wavelet)
signal = np.cumsum(signal) signal = np.cumsum(signal)
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2] signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
mi, ma = np.min(signal), np.max(signal)
signal = (signal - mi) / (ma - mi)
mod = 0.8
signal *= mod
signal += 1 - mod
return signal return signal
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
signal += awgn signal += awgn
# min-max normalization # min-max normalization
signal = signal - np.min(signal) # signal = signal - np.min(signal)
signal = signal / np.max(signal) # signal = signal / np.max(signal)
return signal return signal
@@ -68,26 +84,132 @@ def generate_wavelet(sps, oversample=3):
class eye_diagram: class eye_diagram:
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True): def __init__(
self,
data,
*,
channel_names=None,
horizontal_bins=256,
vertical_bins=1000,
n_levels=4,
multithreaded=True,
save_file_or_dir=None,
):
# data has shape [channels, 2, samples] # data has shape [channels, 2, samples]
# each sample has a timestamp and a value # each sample has a timestamp and a value
if data.ndim == 2: if data.ndim == 2:
data = data[np.newaxis, :, :] data = data[np.newaxis, :, :]
self.channel_names = channel_names
self.raw_data = data self.raw_data = data
self.channels = data.shape[0]
self.y_bins = np.zeros(1)
self.x_bins = np.zeros(1)
self.eye_data = np.zeros(1)
self.channel_names = channel_names
self.n_channels = data.shape[0]
self.n_levels = n_levels self.n_levels = n_levels
self.eye_stats = [{"success": False} for _ in range(self.channels)] self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
self.horizontal_bins = horizontal_bins self.horizontal_bins = horizontal_bins
self.vertical_bins = vertical_bins self.vertical_bins = vertical_bins
self.multi_threaded = multithreaded self.multi_threaded = multithreaded
self.analysed = False
self.eye_built = False self.eye_built = False
def generate_eye_data(self): self.save_file = save_file_or_dir
def load_data(self, file=None):
file = self.save_file if file is None else file
if file is None:
raise FileNotFoundError("No file specified.")
self.save_file = str(file)
# self.file_or_dir = self.save_file
with h5py.File(file, "r") as infile:
self.y_bins = infile["y_bins"][:]
self.x_bins = infile["x_bins"][:]
self.eye_data = infile["eye_data"][:]
self.channel_names = infile.attrs["channel_names"]
self.n_channels = infile.attrs["n_channels"]
self.n_levels = infile.attrs["n_levels"]
self.eye_stats = infile.attrs["eye_stats"]
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
self.horizontal_bins = infile.attrs["horizontal_bins"]
self.vertical_bins = infile.attrs["vertical_bins"]
self.multi_threaded = infile.attrs["multithreaded"]
self.analysed = infile.attrs["analysed"]
self.eye_built = infile.attrs["eye_built"]
def save_data(self, file_or_dir=None):
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
if file_or_dir is None:
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
elif Path(file_or_dir).is_dir():
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
else:
file = Path(file_or_dir)
# file.parent.mkdir(parents=True, exist_ok=True)
self.save_file = str(file)
with h5py.File(file, "w") as outfile:
outfile.create_dataset("eye_data", data=self.eye_data)
outfile.create_dataset("y_bins", data=self.y_bins)
outfile.create_dataset("x_bins", data=self.x_bins)
outfile.attrs["channel_names"] = self.channel_names
outfile.attrs["n_channels"] = self.n_channels
outfile.attrs["n_levels"] = self.n_levels
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
outfile.attrs["horizontal_bins"] = self.horizontal_bins
outfile.attrs["vertical_bins"] = self.vertical_bins
outfile.attrs["multithreaded"] = self.multi_threaded
outfile.attrs["analysed"] = self.analysed
outfile.attrs["eye_built"] = self.eye_built
@staticmethod
def convert_arrays(input_object):
"""
convert ndarrays in (nested) dict to lists
"""
if isinstance(input_object, np.ndarray):
return input_object.tolist()
elif isinstance(input_object, list):
return [eye_diagram.convert_arrays(old) for old in input_object]
elif isinstance(input_object, tuple):
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
elif isinstance(input_object, dict):
dict_out = {}
for key, value in input_object.items():
dict_out[key] = eye_diagram.convert_arrays(value)
return dict_out
return input_object
def generate_eye_data(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
if not self.eye_built:
update_save = True
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False) self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.channels, self.vertical_bins)) self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins)) self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
datas = [self.raw_data[i] for i in range(self.channels)] datas = [self.raw_data[i] for i in range(self.n_channels)]
if self.multi_threaded: if self.multi_threaded:
with multiprocessing.Pool() as pool: with multiprocessing.Pool() as pool:
results = pool.map(self.generate_eye_data_single, datas) results = pool.map(self.generate_eye_data_single, datas)
@@ -98,54 +220,112 @@ class eye_diagram:
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data) self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True self.eye_built = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
def generate_eye_data_single(self, data): def generate_eye_data_single(self, data):
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins)) eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
data_min = np.min(data[1, :]) data_min = np.min(data[1, :])
data_max = np.max(data[1, :]) data_max = np.max(data[1, :])
# round down/up to 1 decimal
data_min = np.floor(data_min*10)/10
data_max = np.ceil(data_max*10)/10
# data_range = data_max - data_min
# data_min -= 0.1 * data_range
# data_max += 0.1 * data_range
# data_min = -0.05
# data_max += 0.05
# data[1,:] -= np.min(data[1, :])
# data[1,:] /= np.max(data[1, :])
# data_min = 0
# data_max = 1
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False) y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
t_vals = data[0, :] % 2 t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
val_vals = data[1, :] val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
x_indices = np.digitize(t_vals, self.x_bins) - 1 x_indices = np.digitize(t_vals, self.x_bins) - 1
y_indices = np.digitize(val_vals, y_bins) - 1 y_indices = np.digitize(val_vals, y_bins) - 1
np.add.at(eye_data, (y_indices, x_indices), 1) np.add.at(eye_data, (y_indices, x_indices), 1)
return eye_data, y_bins return eye_data, y_bins
def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True): def plot(
self,
title="Eye Diagram",
stats=True,
all_stats=True,
show=True,
mode: Literal["default", "load", "save", "nosave"] = "default",
# save_images = False,
# image_dir = None,
# cmap=None,
):
if stats and not self.analysed:
self.analyse(mode=mode)
if not self.eye_built: if not self.eye_built:
self.generate_eye_data() self.generate_eye_data(mode=mode)
cmap = LinearSegmentedColormap.from_list( cmap = LinearSegmentedColormap.from_list(
"eyemap", "eyemap",
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")], [
(0, "#FFFFFF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
) )
if self.channels % 2 == 0: # cmap = cm('google:turbo_r' if cmap is None else cmap)
# first = cmap(-1)
# cmap = cmap.to_mpl()
# cmap.set_under(first, alpha=0)
if self.n_channels % 2 == 0:
rows = 2 rows = 2
cols = self.channels // 2 cols = self.n_channels // 2
else: else:
cols = int(np.ceil(np.sqrt(self.channels))) cols = int(np.ceil(np.sqrt(self.n_channels)))
rows = int(np.ceil(self.channels / cols)) rows = int(np.ceil(self.n_channels / cols))
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False) fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
fig.suptitle(title) fig.suptitle(title)
fig.tight_layout() fig.tight_layout()
ax = np.atleast_1d(ax).transpose().flatten() ax = np.atleast_1d(ax).transpose().flatten()
for i in range(self.channels): for i in range(self.n_channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}") ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
if (i+1) % rows == 0: if (i + 1) % rows == 0:
ax[i].set_xlabel("Symbol") ax[i].set_xlabel("Symbol")
if i < rows: if i < rows:
ax[i].set_ylabel("Amplitude") ax[i].set_ylabel("Amplitude")
ax[i].grid() ax[i].grid()
ax[i].set_axisbelow(True)
ax[i].imshow( ax[i].imshow(
self.eye_data[i], self.eye_data[i] - 0.1,
origin="lower", origin="lower",
aspect="auto", aspect="auto",
cmap=cmap, cmap=cmap,
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]], extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
interpolation="gaussian",
vmin=0,
zorder=3,
) )
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1])) ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
ymin = np.min(self.y_bins[:, 0]) ymin = np.min(self.y_bins[:, 0])
ymax = np.max(self.y_bins[:, -1]) ymax = np.max(self.y_bins[:, -1])
yspan = ymax - ymin yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan)) ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
# if save_images:
# image_dir = "images_out" if image_dir is None else image_dir
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
# image_path.parent.mkdir(parents=True, exist_ok=True)
# # plt.imsave(
# # image_path,
# # self.eye_data[i] - 0.1,
# # origin="lower",
# # # aspect="auto",
# # cmap=cmap,
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
# # # interpolation="gaussian",
# # vmin=0,
# # # zorder=3,
# # )
if stats and self.eye_stats[i]["success"]: if stats and self.eye_stats[i]["success"]:
# # add min_area above the plot # # add min_area above the plot
# ax[i].annotate( # ax[i].annotate(
@@ -159,7 +339,7 @@ class eye_diagram:
if all_stats: if all_stats:
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--") ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
y_ticks = (*self.eye_stats[i]["levels"],*self.eye_stats[i]["thresholds"]) y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
# y_ticks = np.sort(y_ticks) # y_ticks = np.sort(y_ticks)
ax[i].set_yticks(y_ticks) ax[i].set_yticks(y_ticks)
# add arrows for amplitudes # add arrows for amplitudes
@@ -235,19 +415,19 @@ class eye_diagram:
def calculate_thresholds(levels): def calculate_thresholds(levels):
ret = np.cumsum(levels, dtype=float) ret = np.cumsum(levels, dtype=float)
ret[2:] = ret[2:] - ret[:-2] ret[2:] = ret[2:] - ret[:-2]
return ret[1:]/2 return ret[1:] / 2
def analyse_single(self, data, index): def analyse_single(self, data, index):
warnings.filterwarnings("error") warnings.filterwarnings("error")
eye_stats = {} eye_stats = {}
eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index] eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
try: try:
approx_levels = eye_diagram.approximate_levels(data, self.n_levels) approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels) time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2 eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
eye_stats["time_midpoint"] = 1.0 # eye_stats["time_midpoint"] = 1.0
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels( eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
data, approx_levels, time_bounds data, approx_levels, time_bounds
@@ -257,9 +437,7 @@ class eye_diagram:
eye_stats["amplitudes"] = np.diff(eye_stats["levels"]) eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
eye_stats["heights"] = eye_diagram.calculate_eye_heights( eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
eye_stats["amplitude_clusters"]
)
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths( eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
data, eye_stats["levels"] data, eye_stats["levels"]
@@ -291,17 +469,39 @@ class eye_diagram:
warnings.resetwarnings() warnings.resetwarnings()
return eye_stats return eye_stats
def analyse(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
def analyse(self): if not self.analysed:
update_save = True
self.eye_stats = [] self.eye_stats = []
if self.multi_threaded: if self.multi_threaded:
with multiprocessing.Pool() as pool: with multiprocessing.Pool() as pool:
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)]) results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
for i, result in enumerate(results): for i, result in enumerate(results):
self.eye_stats.append(result) self.eye_stats.append(result)
else: else:
for i in range(self.channels): for i in range(self.n_channels):
self.eye_stats.append(self.analyse_single(self.raw_data[i], i)) self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
self.analysed = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
@staticmethod @staticmethod
def approximate_levels(data, levels): def approximate_levels(data, levels):
@@ -443,7 +643,7 @@ class eye_diagram:
if __name__ == "__main__": if __name__ == "__main__":
length = int(2**14) length = int(2**16)
# data = generate_sample_data(length, noise=1) # data = generate_sample_data(length, noise=1)
# data1 = generate_sample_data(length, noise=0.01) # data1 = generate_sample_data(length, noise=0.01)
# data2 = generate_sample_data(length, noise=0.01, skew=1.2) # data2 = generate_sample_data(length, noise=0.01, skew=1.2)
@@ -451,13 +651,13 @@ if __name__ == "__main__":
# data = np.stack([data, data1, data2, data3]) # data = np.stack([data, data1, data2, data3])
data = generate_sample_data(length, noise=0.005) data = generate_sample_data(length, noise=0.0000)
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256) eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
eye.analyse() eye.plot(mode="nosave", stats=False)
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area") # attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
for i, channel in enumerate(eye.eye_stats): # for i, channel in enumerate(eye.eye_stats):
print(f"Channel {i}") # print(f"Channel {i}")
print_data = {attr: channel[attr] for attr in attrs} # print_data = {attr: channel[attr] for attr in attrs}
print(print_data) # print(print_data)
eye.plot() # eye.plot()

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
# modified by Joseph Hopfmüller in 2025,
# for integration into optical regeneration analysis scripts
from pathlib import Path
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as colors
import numpy as _np
from .core import grid_count as _grid_count
import matplotlib.pyplot as _plt
import numpy as np
from scipy.ndimage import gaussian_filter
# from ._common import _common_doc
__all__ = ["eyediagram"] # , 'eyediagram_lines']
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
# """
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
# function.
# <common>
# """
# start = offset
# while start < len(y):
# end = start + window_size
# if end > len(y):
# end = len(y)
# yy = y[start:end+1]
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
# start = end
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
# _common_doc)
eyemap = LinearSegmentedColormap.from_list(
"eyemap",
[
(0, "#0000FF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
)
def eyediagram(
y,
window_size,
offset=0,
colorbar=False,
show=False,
save_im=False,
overwrite=False,
blur: int | bool = True,
save_path="out.png",
bounds=None,
**imshowkwargs,
):
"""
Plot an eye diagram using matplotlib by creating an image and calling
the `imshow` function.
<common>
"""
if bounds is None:
ymax = y.max()
ymin = y.min()
yamp = ymax - ymin
ymin = ymin - 0.05 * yamp
ymax = ymax + 0.05 * yamp
ymin = np.floor(ymin * 10) / 10
ymax = np.ceil(ymax * 10) / 10
bounds = (ymin, ymax)
counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
counts = counts.astype(_np.float32)
origin = imshowkwargs.pop("origin", "lower")
cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
vmin = imshowkwargs.pop("vmin", 1)
vmax = imshowkwargs.pop("vmax", None)
cmap.set_under("white", alpha=0)
if show:
_plt.imshow(
counts.T[::-1, :],
extent=[0, 2, *bounds],
origin=origin,
cmap=cmap,
vmin=vmin,
vmax=vmax,
**imshowkwargs,
)
_plt.grid()
if colorbar:
_plt.colorbar()
if Path(save_path).is_file() and not overwrite:
save_im = False
if save_im:
from PIL import Image
arr = counts.T[::-1, :]
if origin == "lower":
arr = arr[::-1]
arr = (arr-arr.min())/(arr.max()-arr.min())
image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
image.save(save_path)
# print("-")
if show:
_plt.show()
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)