model robustness testing

Joseph Hopfmüller
2025-01-10 23:40:54 +01:00
parent 3af73343c1
commit f38d0ca3bb
13 changed files with 1558 additions and 334 deletions

.gitignore

@@ -163,3 +163,4 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+tolerance_results/datasets/*

notes/models.md

@@ -0,0 +1,37 @@
# models
## no polarisation flipping
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241230_011907.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241230_103752.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241230_164534.tar"
```
## with polarisation flipping
Polarisation flipping: the signal is randomly rotated by 180°. Polarisation rotation can be detected by adding a tone on one of the polarisations, but with a direct-detection setup only modulo 180°. Training on randomly flipped signals should therefore let the network learn to compensate for dispersion and PMD independently of the polarisation rotation. The training data includes the flipped signals as well, with no indication of whether the polarisation is flipped.
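A minimal sketch of this augmentation (illustrative, not taken from the training code): a 180° rotation of the Jones vector is R(180°) = -I, i.e. a sign flip of both field components, so the intensity, and with it any direct-detection monitor, is unchanged.
```py
import torch

def random_pol_flip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Rotate the (E_x, E_y) Jones vector of random batch elements by 180°.
    R(180°) = -I, so the flip is a sign change of both components and is
    invisible to direct (intensity) detection."""
    flip = torch.rand(x.shape[0], device=x.device) < p  # per-example coin flip
    x = x.clone()
    x[flip] = -x[flip]
    return x
```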
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241231_000328.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241231_163614.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241231_170532.tar"
```

View File

@@ -124,7 +124,7 @@ class regenerator(Module):
         parametrizations: list[dict] = None,
         dtype=torch.float64,
         dropout_prob=0.01,
-        scale_layers=False,
+        prescale=1,
         rotate=False,
     ):
         super(regenerator, self).__init__()
@@ -134,15 +134,14 @@ class regenerator(Module):
         act_func_kwargs = act_func_kwargs or {}
         self.rotation = rotate
+        self.prescale = prescale
-        self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
+        self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)

-    def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
+    def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
         for i in range(0, self._n_hidden_layers):
             self.add_module(f"layer_{i}", Sequential())
-            if scale_layers:
-                self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
             module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
             self.get_submodule(f"layer_{i}").add_module("ONN", module)
@@ -156,8 +155,8 @@ class regenerator(Module):
         self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
-        if scale_layers:
-            self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
+        # if scale_layers:
+        #     self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
         module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
         self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
@@ -200,6 +199,7 @@ class regenerator(Module):
         return powers

     def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
+        x = x * self.prescale
         powers = self._trace_powers(trace_powers, x)
         # x = self.layer_0(x)
         # powers = self._trace_powers(trace_powers, x, powers)

View File

@@ -683,7 +683,7 @@ class RegenerationTrainer:
     def define_model(self, model_kwargs=None):
         if self.resume:
-            model_kwargs = self.checkpoint_dict["model_kwargs"]
+            model_kwargs = None
         else:
             model_kwargs = model_kwargs
@@ -692,6 +692,14 @@ class RegenerationTrainer:
             input_dim = 2 * self.data_settings.output_size
+            # if self.data_settings.polarisations:
+            #     input_dim *= 2
+            output_dim = self.model_settings.output_dim
+            # if self.data_settings.polarisations:
+            output_dim *= 2

             dtype = getattr(torch, self.data_settings.dtype)
             afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
@@ -703,7 +711,7 @@ class RegenerationTrainer:
             hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
             self.model_kwargs = {
-                "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
+                "dims": (input_dim, *hidden_dims, output_dim),
                 "layer_function": layer_func,
                 "layer_func_kwargs": self.model_settings.model_layer_kwargs,
                 "act_function": afunc,
@@ -711,7 +719,7 @@ class RegenerationTrainer:
                 "parametrizations": layer_parametrizations,
                 "dtype": dtype,
                 "dropout_prob": self.model_settings.dropout_prob,
-                "scale_layers": self.model_settings.scale,
+                "prescale": self.model_settings.scale,
             }
         else:
             self.model_kwargs = model_kwargs
@@ -745,11 +753,12 @@ class RegenerationTrainer:
         num_symbols = None
         config_path = self.data_settings.config_path
         randomise_polarisations = self.data_settings.randomise_polarisations
+        polarisations = self.data_settings.polarisations
         osnr = self.data_settings.osnr
         if override is not None:
             num_symbols = override.get("num_symbols", None)
             config_path = override.get("config_path", config_path)
-            # polarisations = override.get("polarisations", polarisations)
+            polarisations = override.get("polarisations", polarisations)
             randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)

         # get dataset
         dataset = FiberRegenerationDataset(
@@ -763,6 +772,7 @@ class RegenerationTrainer:
             real=not dtype.is_complex,
             num_symbols=num_symbols,
             randomise_polarisations=randomise_polarisations,
+            polarisations=polarisations,
             osnr = osnr,
         )
@@ -832,17 +842,19 @@ class RegenerationTrainer:
         running_loss = 0.0
         self.model.train()
         loader_len = len(train_loader)
+        x_key = "x_stacked"  # if self.data_settings.polarisations else "x"
+        y_key = "y_stacked"  # if self.data_settings.polarisations else "y"
         for batch_idx, batch in enumerate(train_loader):
-            x = batch["x"]
-            y = batch["y"]
-            angles = batch["mean_angle"]
+            x = batch[x_key]
+            y = batch[y_key]
+            angle = batch["mean_angle"]
             self.model.zero_grad(set_to_none=True)
-            x, y, angles = (
+            x, y, angle = (
                 x.to(self.pytorch_settings.device),
                 y.to(self.pytorch_settings.device),
-                angles.to(self.pytorch_settings.device),
+                angle.to(self.pytorch_settings.device),
             )
-            y_pred = self.model(x, -angles)
+            y_pred = self.model(x, -angle)
             loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
             loss_value = loss.item()
             loss.backward()
@@ -886,17 +898,19 @@ class RegenerationTrainer:
         self.model.eval()
         running_error = 0
+        x_key = "x_stacked"  # if self.data_settings.polarisations else "x"
+        y_key = "y_stacked"  # if self.data_settings.polarisations else "y"
         with torch.no_grad():
             for _, batch in enumerate(valid_loader):
-                x = batch["x"]
-                y = batch["y"]
-                angles = batch["mean_angle"]
+                x = batch[x_key]
+                y = batch[y_key]
+                angle = batch["mean_angle"]
-                x, y, angles = (
+                x, y, angle = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
-                    angles.to(self.pytorch_settings.device),
+                    angle.to(self.pytorch_settings.device),
                 )
-                y_pred = self.model(x, -angles)
+                y_pred = self.model(x, -angle)
                 error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                 error_value = error.item()
                 running_error += error_value
@@ -953,15 +967,17 @@ class RegenerationTrainer:
         regen = []
         timestamps = []
         angles = []
+        x_key = "x_stacked"  # if self.data_settings.polarisations else "x"
+        y_key = "y_stacked"  # if self.data_settings.polarisations else "y"
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
             for batch in loader:
-                x = batch["x"]
-                y = batch["y"]
+                x = batch[x_key]
+                y = batch[y_key]
                 plot_target = batch["plot_target"]
                 angle = batch["mean_angle"]
-                center_angle = batch["center_angle"]
+                # center_angle = batch["center_angle"]
                 timestamp = batch["timestamp"]
                 plot_data = batch["plot_data"]
                 plot_data_rot = batch["plot_data_rot"]
@@ -971,14 +987,16 @@ class RegenerationTrainer:
                     angle.to(self.pytorch_settings.device),
                 )
                 if trace_powers:
-                    y_pred, powers = model(x, angle, True).cpu()
+                    y_pred, powers = model(x, -angle, True).cpu()
                 else:
-                    y_pred = model(x, angle).cpu()
+                    y_pred = model(x, -angle).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
+                # if self.data_settings.polarisations:
+                y_pred = y_pred[:, :2]
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                 y_pred = y_pred[:, y_pred.shape[1]//2, :]
-                y = y.view(y.shape[0], -1, 2)
+                # y = y.view(y.shape[0], -1, 2)
                 # plot_data = plot_data.view(plot_data.shape[0], -1, 2)
                 # c = torch.cos(-angle).cpu()
                 # s = torch.sin(-angle).cpu()
@@ -996,7 +1014,7 @@ class RegenerationTrainer:
                 fiber_in.append(plot_target.squeeze())
                 regen.append(y_pred.squeeze())
                 timestamps.append(timestamp.squeeze())
-                angles.append(center_angle.squeeze())
+                angles.append(angle.squeeze())

         fiber_out = torch.vstack(fiber_out).cpu()
         fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
@@ -1352,7 +1370,8 @@ class RegenerationTrainer:
                 "num_symbols": self.pytorch_settings.batchsize,
                 "config_path": config_path,
                 "shuffle": False,
-                "polarisations": (np.random.rand(1) * np.pi * 2,),
+                # "polarisations": (np.random.rand(1) * np.pi * 2,),
+                "polarisations": self.data_settings.polarisations,
                 "randomise_polarisation": self.data_settings.randomise_polarisations,
             }
         )
@@ -1366,7 +1385,7 @@ class RegenerationTrainer:
         fiber_out_rot = fiber_out_rot.view(-1, 2)
         angles = angles.view(-1, 1)
         angles = angles.real
-        angles = torch.fmod(angles, 2 * torch.pi)
+        angles = torch.fmod(angles, 2*torch.pi)
         angles = torch.div(angles, 2*torch.pi)
         angles = torch.repeat_interleave(angles, 2, dim=1)

View File

@@ -0,0 +1,253 @@
import os

from matplotlib import pyplot as plt
import numpy as np
import torch

import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models


# def move_to_location_in_size(array, location, size):
#     array_x, array_y = array.shape
#     location_x, location_y = location
#     size_x, size_y = size
#     left = location_x
#     right = size_x - array_x - location_x
#     top = location_y
#     bottom = size_y - array_y - location_y
#     return np.pad(
#         array,
#         (
#             (left, right),
#             (top, bottom),
#         ),
#         constant_values=(-np.inf, -np.inf),
#     )


def pad_to_size(array, size):
    if not hasattr(size, "__len__"):
        size = (size, size)
    left = (size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
    right = (size[0] - array.shape[0]) // 2 if size[0] is not None else 0
    top = (size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
    bottom = (size[1] - array.shape[1]) // 2 if size[1] is not None else 0
    array: np.ndarray = array
    if array.ndim == 2:
        return np.pad(
            array,
            (
                (left, right),
                (top, bottom),
            ),
            constant_values=(np.nan, np.nan),
        )
    elif array.ndim == 3:
        return np.pad(
            array,
            (
                (left, right),
                (top, bottom),
                (0, 0),
            ),
            constant_values=(np.nan, np.nan),
        )


def model_plot(model_path):
    torch.serialization.add_safe_globals([
        *util.complexNN.__all__,
        GlobalSettings,
        DataSettings,
        ModelSettings,
        OptimizerSettings,
        PytorchSettings,
        models.regenerator,
        torch.nn.utils.parametrizations.orthogonal,
    ])
    checkpoint_dict = torch.load(model_path, weights_only=True)

    dims = checkpoint_dict["model_kwargs"].pop("dims")
    model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
    model.load_state_dict(checkpoint_dict["model_state_dict"])

    model_params = []
    plots = []
    max_size = np.max(dims)
    # max_act_size = np.max(dims[1:])
    angles = [None, None]
    weights = [None, None]
    for num, (layer_name, layer) in enumerate(model.named_children()):
        # each layer contains an "ONN" layer and an "activation" layer
        # activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
        onn_weights = layer.ONN.weight.T
        onn_weights = onn_weights.detach().cpu().numpy()
        onn_values = np.abs(onn_weights).real
        onn_angles = np.mod(np.angle(onn_weights), 2 * np.pi).real

        act = layer.activation
        act_values = np.ones((act.size, 1))
        act_values = np.nan * act_values
        act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy()
        ...
        # act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8)
        # act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8)
        # xs = (0.01, 0.1, 1)
        # act_values = np.zeros((act.size, len(xs)*2))
        # act_angles = np.zeros((act.size, len(xs)*2))
        # act_values[:,:] = np.nan
        # act_angles[:,:] = np.nan
        # for xi, x in enumerate(xs):
        #     phi_intermediate = act_phi_gain * x**2 + act_phi_bias
        #     act_resulting_gain = (
        #         1j
        #         * torch.sqrt(1-act.alpha)
        #         * torch.exp(-0.5j * phi_intermediate)
        #         * torch.cos(0.5 * phi_intermediate)
        #         * x
        #     )
        #     act_resulting_gain = act_resulting_gain.detach().cpu().numpy()
        #     act_values[:, xi*2] = np.abs(act_resulting_gain).real
        #     act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real

        # if angles[0] is None or angles[0] > np.min(onn_angles.flatten()):
        #     angles[0] = np.min(onn_angles.flatten())
        # if angles[1] is None or angles[1] < np.max(onn_angles.flatten()):
        #     angles[1] = np.max(onn_angles.flatten())
        # if weights[0] is None or weights[0] > np.min(onn_weights.flatten()):
        #     weights[0] = np.min(onn_weights.flatten())
        # if weights[1] is None or weights[1] < np.max(onn_weights.flatten()):
        #     weights[1] = np.max(onn_weights.flatten())

        model_params.append({layer_name: onn_weights})
        plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)})

    # fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
    for plot in plots:
        layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem()
        # for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0])
        # for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0])
        onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values))
        onn_values = onn_values - np.min(onn_values)
        onn_values = onn_values / np.max(onn_values)
        act_values = np.ma.array(act_values, mask=np.isnan(act_values))
        act_values = act_values - np.min(act_values)
        act_values = act_values / np.max(act_values)

        onn_values = pad_to_size(onn_values, (max_size, None))
        act_values = pad_to_size(act_values, (max_size, 3))
        onn_angles = onn_angles / np.pi
        onn_angles = pad_to_size(onn_angles, (max_size, None))
        act_angles = act_angles / np.pi
        act_angles = pad_to_size(act_angles, (max_size, 3))
        # onn_angles = onn_angles - np.min(onn_angles)
        # onn_angles = onn_angles / np.max(onn_angles)
        # act_angles = act_angles - np.min(act_angles)
        # act_angles = act_angles / np.max(act_angles)

        if num == 0:
            value_img = np.concatenate((onn_values, act_values), axis=1)
            angle_img = np.concatenate((onn_angles, act_angles), axis=1)
        else:
            value_img = np.concatenate((value_img, onn_values, act_values), axis=1)
            angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1)

    # -np.inf to np.nan
    # value_img[value_img == -np.inf] = np.nan
    # angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size)
    # angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size)

    from cmcrameri import cm
    from matplotlib import colors as mcolors

    alpha_map = mcolors.LinearSegmentedColormap(
        'alphamap',
        {
            'red': [(0, 0, 0), (1, 0, 0)],
            'green': [(0, 0, 0), (1, 0, 0)],
            'blue': [(0, 0, 0), (1, 0, 0)],
            'alpha': [
                (0, 1, 1),
                # (0.2, 0.2, 0.1),
                (1, 0, 0)
            ]
        }
    )
    alpha_map.set_bad(color="#AAAAAA")

    fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5))
    fig.tight_layout()
    # masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
    masked_value_img = value_img
    cmap = cm.batlowW
    cmap.set_bad(color="#AAAAAA")
    im_val = axs[0].imshow(masked_value_img, cmap=cmap)

    masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
    cmap = cm.romaO
    cmap.set_bad(color="#AAAAAA")
    im_ang = axs[1].imshow(masked_angle_img, cmap=cmap)

    im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
    im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)

    axs[0].axis("off")
    axs[1].axis("off")
    axs[2].axis("off")
    axs[0].set_title("Values")
    axs[1].set_title("Angles")
    axs[2].set_title("Values and Angles")
    ...
    plt.show()
    # model = models.regenerator(*dims, **model_kwargs)


if __name__ == "__main__":
    model_plot(".models/best_20250105_145719.tar")

View File

@@ -1,4 +1,4 @@
-from datetime import datetime
+# from datetime import datetime
 from pathlib import Path
 import matplotlib
 import numpy as np
@@ -13,7 +13,7 @@ from hypertraining.settings import (
     OptimizerSettings,
 )
-from hypertraining.training import RegenerationTrainer, PolarizationTrainer
+from hypertraining.training import RegenerationTrainer  # , PolarizationTrainer

 # import torch
 import json
@@ -27,7 +27,7 @@ global_settings = GlobalSettings(
 data_settings = DataSettings(
     # config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
-    config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",
+    config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini",
     # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -37,12 +37,13 @@ data_settings = DataSettings(
     shuffle=True,
     drop_first=64,
     train_split=0.8,
-    randomise_polarisations=True,
-    osnr=10,
+    randomise_polarisations=False,
+    polarisations=True,
+    osnr=16,  # 16 dB due to amplification with NF 5
 )

 pytorch_settings = PytorchSettings(
-    epochs=10000,
+    epochs=1000,
     batchsize=2**14,
     device="cuda",
     dataloader_workers=24,
@@ -64,11 +65,11 @@ model_settings = ModelSettings(
         # "n_hidden_nodes_3": 4,
         # "n_hidden_nodes_4": 2,
     },
-    model_activation_func="EOActivation",
+    model_activation_func="phase_shift",
     dropout_prob=0,
     model_layer_function="ONNRect",
     model_layer_kwargs={"square": True},
-    scale=False,
+    scale=2.0,
     model_layer_parametrizations=[
         {
             "tensor_name": "weight",
@@ -77,13 +78,17 @@ model_settings = ModelSettings(
         {
             "tensor_name": "alpha",
             "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": 1,
+            },
         },
         {
             "tensor_name": "gain",
             "parametrization": util.complexNN.clamp,
             "kwargs": {
                 "min": 0,
-                "max": float("inf"),
+                "max": None,
             },
         },
         {
@@ -95,8 +100,12 @@ model_settings = ModelSettings(
             },
         },
         {
-            "tensor_name": "scales",
+            "tensor_name": "scale",
             "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": 2,
+            },
         },
         {
             "tensor_name": "angle",
@@ -244,9 +253,17 @@ if __name__ == "__main__":
         pytorch_settings=pytorch_settings,
         model_settings=model_settings,
         optimizer_settings=optimizer_settings,
-        checkpoint_path=".models/best_20241216_221359.tar",
+        # checkpoint_path=".models/best_20250104_191428.tar",
         reset_epoch=True,
         # settings_override={
+        #     "data_settings": {
+        #         "config_path": "data/20241229-163*-128-16384-100000-*.ini",
+        #         "polarisations": True,
+        #     },
+        #     "model_settings": {
+        #         "scale": 2.0,
+        #     }
+        # }
         #     "optimizer_settings": {
        #         "optimizer_kwargs": {
        #             "lr": 0.01,

View File

@@ -16,16 +16,17 @@ from datetime import datetime
 import hashlib
 from pathlib import Path
 import time

+import h5py
 from matplotlib import pyplot as plt  # noqa: F401
 import numpy as np

-import add_pypho  # noqa: F401
+from . import add_pypho  # noqa: F401
 import pypho

 default_config = f"""
 [glova]
-nos = 256
-sps = 256
+sps = 128
+nos = 16384
 f0 = 193414489032258.06
 symbolrate = 10e9
 wisdom_dir = "{str((Path.home() / ".pypho"))}"
@@ -37,9 +38,9 @@ length = 10000
 gamma = 1.14
 alpha = 0.2
 D = 17
-S = 0
-birefsteps = 0
-max_delta_beta = 0.4
+S = 0.058
+bireflength = 10
+max_delta_beta = 0.14
 ; birefseed = 0xC0FFEE

 [signal]
@@ -47,17 +48,15 @@ max_delta_beta = 0.4
 modulation = "pam"
 mod_order = 4
-mod_depth = 0.8
+mod_depth = 1
 max_jitter = 0.02
 ; jitter_seed = 0xC0FFEE
 laser_power = 0
-edfa_power = 3
+edfa_power = 0
 edfa_nf = 5
 pulse_shape = "gauss"
 fwhm = 0.33
-osnr = "inf"

 [data]
 dir = "data"
@@ -71,6 +70,7 @@ def get_config(config_file=None):
     """
     if config_file is None:
         config_file = Path(__file__).parent / "signal_generation.ini"
+    config_file = Path(config_file)
     if not config_file.exists():
         with open(config_file, "w") as f:
             f.write(default_config)
@@ -83,7 +83,10 @@ def get_config(config_file=None):
             conf[section] = {}
         for key in config[section]:
             # print(f"{key} = {config[section][key]}")
-            conf[section][key] = eval(config[section][key])
+            try:
+                conf[section][key] = eval(config[section][key])
+            except NameError:
+                conf[section][key] = float(config[section][key])
             # if isinstance(conf[section][key], str):
             #     conf[section][key] = config[section][key].strip('"')
     return conf
@@ -96,7 +99,9 @@ class PDM_IM_IPM:
         mod_order=8,
         seed=None,
     ):
-        assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1"
+        assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
+            "mod_order must be a cube of an integer greater than 1"
+        )
         self.glova = glova
         self.mod_order = mod_order
         self.symbols_per_dim = int(np.cbrt(mod_order))
@@ -110,14 +115,7 @@
 class pam_generator:
     def __init__(
-        self,
-        glova,
-        mod_order=None,
-        mod_depth=0.5,
-        pulse_shape="gauss",
-        fwhm=0.33,
-        seed=None,
-        single_channel=False
+        self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
     ) -> None:
         self.glova = glova
         self.pulse_shape = pulse_shape
@@ -138,9 +136,7 @@
         symbols_x = symbols[0] / (self.mod_order)
         diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
         digital_x = self.generate_digital_signal(diffs_x, max_jitter)
-        digital_x = np.pad(
-            digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
-        )
+        digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))

         # create analog signal of diff of symbols
         E_x = np.convolve(digital_x, wavelet)
@@ -158,16 +154,13 @@
             symbols_y = symbols[1] / (self.mod_order)
             diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
             digital_y = self.generate_digital_signal(diffs_y, max_jitter)
-            digital_y = np.pad(
-                digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
-            )
+            digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))

             E_y = np.convolve(digital_y, wavelet)
             E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
             E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]

             E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
             # rotate the signal on the y-polarisation by 90°
@@ -176,7 +169,6 @@
             E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
         return E

-
     def generate_digital_signal(self, symbols, max_jitter=0):
         rs = np.random.RandomState(self.seed)
         signal = np.zeros(self.glova.nos * self.glova.sps)
@@ -198,15 +190,11 @@
             endpoint=True,
         )
         sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
-        pulse = (
-            1
-            / (sigma * np.sqrt(2 * np.pi))
-            * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
-        )
+        pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
         return pulse


-def initialize_fiber_and_data(config, input_data_override=None):
+def initialize_fiber_and_data(config):
     py_glova = pypho.setup(
         nos=config["glova"]["nos"],
         sps=config["glova"]["sps"],
@@ -221,22 +209,14 @@ def initialize_fiber_and_data(config, input_data_override=None):
     c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
     py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])

-    if input_data_override is not None:
-        c_data.E_in = input_data_override[0]
-        noise = input_data_override[1]
-    else:
-        config["signal"]["seed"] = config["signal"].get(
-            "seed", (int(time.time() * 1000)) % 2**32
-        )
-        config["signal"]["jitter_seed"] = config["signal"].get(
-            "jitter_seed", (int(time.time() * 1000)) % 2**32
-        )
+    osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
+    config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
+    config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)

     symbolsrc = pypho.symbols(
         py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
     )
-    laser = pypho.lasmod(
-        py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
-    )
+    laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4)
     modulator = pam_generator(
         py_glova,
         mod_depth=config["signal"]["mod_depth"],
@@ -252,14 +232,28 @@ def initialize_fiber_and_data(config):
     symbols_y[:3] = 0
     # symbols_x += 1

     cw = laser()
     source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))

+    if osnr != float("inf"):
+        osnr_lin = 10 ** (osnr / 10)
+        signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
+        noise_power = signal_power / osnr_lin
+        noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
+            0, 1, source_signal[0]["E"].shape
+        )
+        noise_power_is = np.sum(pypho.functions.getpower_W(noise))
+        noise = noise * np.sqrt(noise_power / noise_power_is)
+        noise_power_is = np.sum(pypho.functions.getpower_W(noise))
+        source_signal[0]["E"] += noise
+        source_signal[0]["noise"] = noise_power_is

     # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
-    source_signal = py_edfa(E=source_signal)
+    nf = py_edfa.NF
+    source_signal = py_edfa(E=source_signal, NF=0)
+    py_edfa.NF = nf

     c_data.E_in = source_signal[0]["E"]
     noise = source_signal[0]["noise"]
@@ -273,25 +267,21 @@ def initialize_fiber_and_data(config):
         S=config["fiber"]["s"],
     )
     if config["fiber"].get("birefsteps", 0) > 0:
-        seed = config["fiber"].get(
-            "birefseed", (int(time.time() * 1000)) % 2**32
-        )
+        seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
         py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
             py_fiber.l,
             py_fiber.l / config["fiber"]["birefsteps"],
             # maxDeltaD=config["fiber"]["d"]/5,
-            maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
+            maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
             seed=seed,
         )

-    c_params = pypho.cfiber.ParamsWrapper.from_fiber(
-        py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
-    )
+    c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200)
     c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)

-    return c_fiber, c_data, noise, py_edfa
+    return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y)


-def save_data(data, config):
+def save_data(data, config, **metadata):
     data_dir = Path(config["data"]["dir"])
     npy_dir = config["data"].get("npy_dir", "")
     save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
@@ -306,6 +296,7 @@ def save_data(data, config, **metadata):
     seed = config["signal"].get("seed", False)
     jitter_seed = config["signal"].get("jitter_seed", False)
     birefseed = config["fiber"].get("birefseed", False)
+    osnr = float(config["signal"].get("osnr", "inf"))

     config_content = "\n".join((
         f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
@@ -317,14 +308,14 @@ def save_data(data, config, **metadata):
         f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
         f'flags = "{config["glova"]["flags"]}"',
         f"nthreads = {config['glova']['nthreads']}",
-        " ",
+        "",
         "[fiber]",
         f"length = {config['fiber']['length']}",
         f"gamma = {config['fiber']['gamma']}",
         f"alpha = {config['fiber']['alpha']}",
         f"D = {config['fiber']['d']}",
         f"S = {config['fiber']['s']}",
-        f"birefsteps = {config['fiber'].get('birefsteps',0)}",
+        f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
         f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
         f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
         "",
@@ -334,49 +325,62 @@ def save_data(data, config, **metadata):
         f'modulation = "{config["signal"]["modulation"]}"',
         f"mod_order = {config['signal']['mod_order']}",
         f"mod_depth = {config['signal']['mod_depth']}",
-        ""
+        "",
         f"max_jitter = {config['signal']['max_jitter']}",
         f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
-        ""
+        "",
         f"laser_power = {config['signal']['laser_power']}",
         f"edfa_power = {config['signal']['edfa_power']}",
         f"edfa_nf = {config['signal']['edfa_nf']}",
-        ""
+        f"osnr = {osnr}",
+        "",
         f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
         f"fwhm = {config['signal']['fwhm']}",
         "",
         "[data]",
         f'dir = "{str(data_dir)}"',
         f'npy_dir = "{npy_dir}"',
-        "file = "
+        "file = ",
     ))

     config_hash = hashlib.md5(config_content.encode()).hexdigest()
-    save_file = f"{config_hash}.npy"
+    save_file = f"{config_hash}.h5"
     config_content += f'"{str(save_file)}"\n'

     filename_components = (
         timestamp.strftime("%Y%m%d-%H%M%S"),
         config["glova"]["sps"],
         config["glova"]["nos"],
+        config["signal"]["osnr"],
         config["fiber"]["length"],
         config["fiber"]["gamma"],
         config["fiber"]["alpha"],
         config["fiber"]["d"],
         config["fiber"]["s"],
         f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
-        config['fiber'].get('birefsteps',0),
+        config["fiber"].get("birefsteps", 0),
         config["fiber"].get("max_delta_beta", 0),
+        int(config["glova"]["symbolrate"] / 1e9),
     )
     lookup_file = "-".join(map(str, filename_components)) + ".ini"
-    with open(data_dir / lookup_file, "w") as f:
+    config_filename = data_dir / lookup_file
+    with open(config_filename, "w") as f:
         f.write(config_content)
-    np.save(save_dir / save_file, save_data)
+    with h5py.File(save_dir / save_file, "w") as outfile:
+        outfile.create_dataset("data", data=save_data)
+        outfile.create_dataset("symbols", data=metadata.pop("symbols"))
+        for key, value in metadata.items():
+            # if isinstance(value, dict):
+            #     value = json.dumps(model_runner.convert_arrays(value))
+            outfile.attrs[key] = value
+    # np.save(save_dir / save_file, save_data)

-    print("Saved config to", data_dir / lookup_file)
+    print("Saved config to", config_filename)
     print("Saved data to", save_dir / save_file)
+    return config_filename


 def length_loop(config, lengths, save=True):
     lengths = sorted(lengths)
@@ -386,23 +390,19 @@ def length_loop(config, lengths, save=True):
         cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
         mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))

         cfiber()

         mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
-        print(
-            f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
-        )
-        print(
-            f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
-        )
+        print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
+        print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")

-        E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
+        E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
         E_tmp = edfa(E=E_tmp)
-        cdata.E_out = E_tmp[0]['E']
+        cdata.E_out = E_tmp[0]["E"]
+        mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
+        print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")

         if save:
             save_data(cdata, config)
@@ -411,27 +411,57 @@
 def single_run_with_plot(config, save=True):
-    cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
-    mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
-    print(
-        f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
-    )
+    cfiber, cdata, config_filename = single_run(config, save)
+    in_out_eyes(cfiber, cdata, show_pols=False)
+    return config_filename
+
+
+def single_run(config, save=True):
+    cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)
+    # mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
+    # print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
+
+    # estimate osnr
+    # noise_power = np.mean(noise)
+    # osnr_lin = mean_power_in / noise_power - 1
+    # osnr = 10 * np.log10(osnr_lin)
+    # print(f"Estimated OSNR: {osnr:.3f} dB")

     cfiber()

-    mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
-    print(
-        f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
-    )
+    # mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
+    # print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")

-    E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
+    # noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)
+    # estimate osnr
+    # noise_power = np.mean(noise)
+    # osnr_lin = mean_power_out / noise_power - 1
+    # osnr = 10 * np.log10(osnr_lin)
+    # print(f"Estimated OSNR: {osnr:.3f} dB")
+
+    E_tmp = [{"E": cdata.E_out, "noise": noise}]
     E_tmp = edfa(E=E_tmp)
-    cdata.E_out = E_tmp[0]['E']
+    cdata.E_out = E_tmp[0]["E"]
+    # noise = E_tmp[0]["noise"]

-    if save:
-        save_data(cdata, config)
-    in_out_eyes(cfiber, cdata, show_pols=False)
+    # mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
+    # print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
+
+    # estimate osnr
+    # noise_power = np.mean(noise)
+    # osnr_lin = mean_power_amp / noise_power - 1
+    # osnr = 10 * np.log10(osnr_lin)
+    # print(f"Estimated OSNR: {osnr:.3f} dB")
+
+    config_filename = None
+    symbols = np.array(symbols)
+    if save:
+        config_filename = save_data(cdata, config, **{"symbols": symbols})
+    return cfiber, cdata, config_filename


 def in_out_eyes(cfiber, cdata, show_pols=False):
     fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
@@ -595,9 +625,7 @@ def plot_eye_diagram(
     signal = signal[: head * eye_width]
     if normalize:
         signal = signal / np.max(signal)
-    slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
-        offset % (eye_width + 1) :: eye_width
-    ]
+    slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
     plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
     for slice in slices:
         ax.plot(plt_ax, slice, color=color, alpha=0.1)
@@ -618,14 +646,26 @@ if __name__ == "__main__":
     # lengths = [*lengths, *lengths]
     lengths = (
         # 8000, 9000,
-        10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000,
-        95000, 100000, 105000, 110000, 115000, 120000
+        10000,
+        20000,
+        30000,
+        40000,
+        50000,
+        60000,
+        70000,
+        80000,
+        90000,
+        95000,
+        100000,
+        105000,
+        110000,
+        115000,
+        120000,
     )
     # lengths = (10000,100000)
-    length_loop(config, lengths, save=True)
+    # length_loop(config, lengths, save=True)
     # birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
-    # single_run_with_plot(config, save=False)
+    single_run_with_plot(config, save=False)

View File

@@ -0,0 +1,723 @@
"""
tests a given model for tolerance against variations in
- fiber length
- baudrate
- OSNR
CD, PMD, baudrate need different datasets, osnr is modeled as awgn added to the data before feeding into the model
"""
from datetime import datetime
from typing import Literal

from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path
import h5py
import torch

import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models
from signal_gen.generate_signal import single_run, get_config
import json


class NestedParameterIterator:
    def __init__(self, parameters):
        """
        parameters: dict with key <param_name> and value <dict with keys "config" and "range">
        """
        # self.parameters = parameters
        self.names = []
        self.ranges = []
        self.configs = []
        for k, v in parameters.items():
            self.names.append(k)
            self.ranges.append(v["range"])
            self.configs.append(v["config"])
        self.n_parameters = len(self.ranges)
        self.idx = 0
        self.range_idx = [0] * self.n_parameters
        self.range_len = [len(r) for r in self.ranges]
        self.length = int(np.prod(self.range_len))
        self.out = []
        for i in range(self.length):
            self.out.append([])
            for j in range(self.n_parameters):
                element = {self.names[j]: {"value": self.ranges[j][self.range_idx[j]], "config": self.configs[j]}}
                self.out[i].append(element)
            self.range_idx[-1] += 1
            # update range_idx back to front
            for j in range(self.n_parameters - 1, -1, -1):
                if self.range_idx[j] == self.range_len[j]:
                    self.range_idx[j] = 0
                    self.range_idx[j - 1] += 1
        ...

    def __next__(self):
        if self.idx == self.length:
            raise StopIteration
        self.idx += 1
        return self.out[self.idx - 1]

    def __iter__(self):
        return self
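# For example (hypothetical values), two registered parameters iterate as their
# Cartesian product, with the last range varying fastest:
#
#     params = {
#         "length": {"config": ("fiber", "length"), "range": [10e3, 50e3]},
#         "osnr": {"config": ("signal", "osnr"), "range": [12, 16, 20]},
#     }
#     for combo in NestedParameterIterator(params):
#         ...  # 2 x 3 = 6 combinations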
class model_runner:
    def __init__(
        self,
        # length_range: tuple[int | float] = (50e3, 50e3),
        # length_steps: int = 1,
        # length_log: bool = False,
        # baudrate_range: tuple[int | float] = (10e9, 10e9),
        # baudrate_steps: int = 1,
        # baudrate_log: bool = False,
        # osnr_range: tuple[int | float] = (16, 16),
        # osnr_steps: int = 1,
        # osnr_log: bool = False,
        # dataset_dir: str = "data",
        # dataset_datetime_glob: str = "*",
        results_dir: str = "tolerance_results/datasets",
        # model_dir: str = ".models",
        config: str = "signal_generation.ini",
        config_dir: str = None,
        debug: bool = False,
    ):
        """
        length_range: lower and upper limit of length, in meters
        length_step: step size of length, in meters
        baudrate_range: lower and upper limit of baudrate, in Bd
        baudrate_step: step size of baudrate, in Bd
        osnr_range: lower and upper limit of osnr, in dB
        osnr_step: step size of osnr, in dB
        dataset_dir: directory containing datasets
        dataset_datetime_glob: datetime glob pattern for dataset files
        results_dir: directory to save results
        model_dir: directory containing models
        """
        self.debug = debug
        self.parameters = {}
        self.iter = None
        # self.update_length_range(length_range, length_steps, length_log)
        # self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log)
        # self.update_osnr_range(osnr_range, osnr_steps, osnr_log)
        # self.data_dir = Path(dataset_dir)
        # self.data_datetime_glob = dataset_datetime_glob
        self.results_dir = Path(results_dir)
        # self.model_dir = Path(model_dir)
        config_dir = config_dir or Path(__file__).parent
        self.config = config_dir / config
        torch.serialization.add_safe_globals([
            *util.complexNN.__all__,
            GlobalSettings,
            DataSettings,
            ModelSettings,
            OptimizerSettings,
            PytorchSettings,
            models.regenerator,
            torch.nn.utils.parametrizations.orthogonal,
        ])
        self.load_model()
        self.datasets = []

    # def register_parameter(self, name, config):
    #     self.parameters.append({"name": name, "config": config})

    def load_results_from_file(self, path):
        data, meta = self.load_from_file(path)
        self.results = [d.decode() for d in data]
        self.parameters = meta["parameters"]
        ...

    def load_datasets_from_file(self, path):
        data, meta = self.load_from_file(path)
        self.datasets = [d.decode() for d in data]
        self.parameters = meta["parameters"]
        ...

    def update_parameter_range(self, name, config, range, steps, log):
        self.parameters[name] = {"config": config, "range": self.update_range(*range, steps, log)}

    def generate_iterations(self):
        if len(self.parameters) == 0:
            raise ValueError("No parameters registered")
        self.iter = NestedParameterIterator(self.parameters)

    def generate_datasets(self):
        # get base config
        config = get_config(self.config)
        if self.iter is None:
            self.generate_iterations()
        for params in self.iter:
            current_settings = []
            # params is a list of dictionaries with keys "name", containing a dict with keys "value", "config"
            for param in params:
                for name, settings in param.items():
                    current_settings.append({name: settings["value"]})
                    self.nested_set(config, settings["config"], settings["value"])
            settings_strs = []
            for setting in current_settings:
                name = list(setting)[0]
                settings_strs.append(f"{name}: {float(setting[name]):.2e}")
            settings_str = ", ".join(settings_strs)
            print(f"Generating dataset for [{settings_str}]")
            # TODO look for existing datasets
            _, _, path = single_run(config)
            self.datasets.append(str(path))

        datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back")
        metadata = {"parameters": self.parameters}
        data = np.array(self.datasets, dtype="S")
        self.save_to_file(datasets_list_path, data, **metadata)

    @staticmethod
    def nested_set(dic, keys, value):
        for key in keys[:-1]:
            dic = dic.setdefault(key, {})
        dic[keys[-1]] = value
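    # For example (hypothetical config): nested_set(cfg, ("fiber", "length"), 80e3)
    # is equivalent to cfg["fiber"]["length"] = 80e3, creating intermediate
    # dicts as needed.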
## Dataset and model loading
# def find_datasets(self, data_dir=None, data_datetime_glob=None):
# # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini
# data_dir = data_dir or self.data_dir
# data_datetime_glob = data_datetime_glob or self.data_datetime_glob
# self.datasets = {}
# data_dir = Path(data_dir)
# for length in self.lengths:
# for baudrate in self.baudrates:
# # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini"
# dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini"
# datasets = [f for f in data_dir.glob(dataset_glob)]
# if len(datasets) == 0:
# continue
# self.datasets[length] = {}
# if len(datasets) > 1:
# print(
# f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset."
# )
# # get newest file from creation date
# datasets.sort(key=lambda x: x.stat().st_ctime)
# self.datasets[length][baudrate] = str(datasets[-1])
def load_dataset(self, dataset_path):
if self.checkpoint_dict is None:
raise ValueError("Model must be loaded before dataset")
if self.dataset_path is None:
self.dataset_path = dataset_path
elif self.dataset_path == dataset_path:
return
symbols = self.checkpoint_dict["settings"]["data_settings"].symbols
data_size = self.checkpoint_dict["settings"]["data_settings"].output_size
dtype = getattr(torch, self.checkpoint_dict["settings"]["data_settings"].dtype)
drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first
randomise_polarisations = self.checkpoint_dict["settings"]["data_settings"].randomise_polarisations
polarisations = self.checkpoint_dict["settings"]["data_settings"].polarisations
num_symbols = None
if self.debug:
num_symbols = 1000
config_path = Path(dataset_path)
dataset = util.datasets.FiberRegenerationDataset(
file_path=config_path,
symbols=symbols,
output_dim=data_size,
drop_first=drop_first,
dtype=dtype,
real=not dtype.is_complex,
randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
num_symbols=num_symbols,
# device="cuda" if torch.cuda.is_available() else "cpu",
)
self.dataloader = torch.utils.data.DataLoader(
dataset, batch_size=2**14, pin_memory=True, num_workers=24, prefetch_factor=8, shuffle=False
)
return self.dataloader.dataset.orig_symbols
# run model
# return results as array: [fiber_in, fiber_out, fiber_out_noisy, regen_out]
def load_model(self, model_path: str | None = None):
if model_path is None:
self.model = None
self.model_path = None
self.checkpoint_dict = None
return
path = Path(model_path)
if self.model_path is None:
self.model_path = path
elif path == self.model_path:
return
self.dataset_path = None # reset dataset path, as the shape depends on the model
self.checkpoint_dict = torch.load(path, weights_only=True)
dims = self.checkpoint_dict["model_kwargs"].pop("dims")
self.model = models.regenerator(*dims, **self.checkpoint_dict["model_kwargs"])
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"])
## Model evaluation
def run_model_evaluation(self, model_path: str, datasets: str | None = None):
self.load_model(model_path)
# iterate over datasets and osnr values:
# load dataset, add noise, run model, return results
# save results to file
self.results = []
if datasets is not None:
self.load_datasets_from_file(datasets)
n_datasets = len(self.datasets)
for i, dataset_path in enumerate(self.datasets):
conf = get_config(dataset_path)
mpath = Path(model_path)
model_base = mpath.stem
print(f"({1+i}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}")
results_path = self.build_path(
dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base
)
orig_symbols = self.load_dataset(dataset_path)
data, loss = self.run_model()
metadata = {
"model_path": model_path,
"dataset_path": dataset_path,
"loss": loss,
"sps": conf["glova"]["sps"],
"orig_symbols": orig_symbols
# "config": conf,
# "checkpoint_dict": self.checkpoint_dict,
# "nos": self.dataloader.dataset.num_symbols,
}
self.save_to_file(results_path, data, **metadata)
self.results.append(str(results_path))
results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back")
metadata = {"parameters": self.parameters}
data = np.array(self.results, dtype="S")
self.save_to_file(results_list_path, data, **metadata)
def run_model(self):
loss = 0
datas = []
self.model.eval()
model = self.model.to("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
for batch in self.dataloader:
x = batch["x_stacked"]
y = batch["y_stacked"]
fiber_in = batch["plot_target"]
# fiber_out = batch["plot_clean"]
fiber_out = batch["plot_data"]
timestamp = batch["timestamp"]
angle = batch["mean_angle"]
x = x.to("cuda" if torch.cuda.is_available() else "cpu")
angle = angle.to("cuda" if torch.cuda.is_available() else "cpu")
regen = model(x, -angle)
regen = regen.to("cpu")
loss += util.complexNN.complex_mse_loss(regen, y, power=True).item()
# shape: [batch_size, 4]
plot_regen = regen[:, :2]
plot_regen = plot_regen.view(plot_regen.shape[0], -1, 2)
plot_regen = plot_regen[:, plot_regen.shape[1] // 2, :]
data_out = torch.cat(
(
fiber_in,
fiber_out,
# fiber_out_noisy,
plot_regen,
timestamp.view(-1, 1),
),
dim=1,
)
datas.append(data_out)
data_out = torch.cat(datas, dim=0).numpy()
return data_out, loss
    ## File I/O
    @staticmethod
    def save_to_file(path: str | Path, data: np.ndarray, **metadata):
        path = Path(path)  # accept both str and Path
        # create directory if it doesn't exist
        path.parent.mkdir(parents=True, exist_ok=True)
        with h5py.File(path, "w") as outfile:
            outfile.create_dataset("data", data=data)
            for key, value in metadata.items():
                if isinstance(value, dict):
                    # h5 attributes cannot hold dicts, so store them as JSON strings
                    value = json.dumps(model_runner.convert_arrays(value))
                outfile.attrs[key] = value
    @staticmethod
    def convert_arrays(dict_in):
        """
        convert ndarrays in (nested) dict to lists
        """
        dict_out = {}
        for key, value in dict_in.items():
            if isinstance(value, dict):
                dict_out[key] = model_runner.convert_arrays(value)
            elif isinstance(value, np.ndarray):
                dict_out[key] = value.tolist()
            else:
                dict_out[key] = value
        return dict_out
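
    # e.g. convert_arrays({"a": np.array([1, 2])}) -> {"a": [1, 2]},
    # which json.dumps can then serialise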
    @staticmethod
    def load_from_file(path: str):
        with h5py.File(path, "r") as infile:
            data = infile["data"][:]
            metadata = {}
            for key in infile.attrs.keys():
                if isinstance(infile.attrs[key], (str, bytes, bytearray)):
                    # JSON-decode the attributes that were serialised dicts
                    try:
                        metadata[key] = json.loads(infile.attrs[key])
                    except json.JSONDecodeError:
                        metadata[key] = infile.attrs[key]
                else:
                    metadata[key] = infile.attrs[key]
        return data, metadata
    ## Utility functions
    @staticmethod
    def logrange(start, stop, num, endpoint=False):
        lower, upper = np.log10((start, stop))
        return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10)
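
    # e.g. logrange(10e9, 100e9, 3, endpoint=True) -> array([1.00e+10, 3.16e+10, 1.00e+11]),
    # the logarithmic baud-rate sweep commented out in __main__ below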
    @staticmethod
    def build_path(
        *elements, parent_dir: str | Path | None = None, filetype="h5", timestamp: Literal["no", "front", "back"] = "no"
    ):
        suffix = f".{filetype}" if not filetype.startswith(".") else filetype
        if timestamp != "no":
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            if timestamp == "front":
                elements = (ts, *elements)
            elif timestamp == "back":
                elements = (*elements, ts)
        path = "_".join(elements)
        path += suffix
        if parent_dir is not None:
            path = Path(parent_dir) / path
        return path
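
    # e.g. build_path("results_list", parent_dir="tolerance_results/datasets", timestamp="back")
    # -> Path("tolerance_results/datasets/results_list_20250110_232639.h5") (timestamp varies per run)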
    @staticmethod
    def update_range(start, stop, n_steps, log):
        # renamed from min/max/range to avoid shadowing the built-ins
        if log:
            values = model_runner.logrange(start, stop, n_steps, endpoint=True)
        else:
            values = np.linspace(start, stop, n_steps, endpoint=True)
        return values
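
    # e.g. update_range(20, 40, 5, log=False) -> array([20., 25., 30., 35., 40.]),
    # the OSNR sweep configured in __main__ below
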
class model_evaluation_result:
    def __init__(
        self,
        *,
        length=None,
        baudrate=None,
        osnr=None,
        model_path=None,
        dataset_path=None,
        loss=None,
        sps=None,
        **kwargs,
    ):
        self.length = length
        self.baudrate = baudrate
        self.osnr = osnr
        self.model_path = model_path
        self.dataset_path = dataset_path
        self.loss = loss
        # sps travels through the h5 attrs, so it may arrive as a string
        self.sps = int(sps) if sps is not None else None
        self.sers = None
        self.bers = None
        self.eye_stats = None
class evaluator:
    def __init__(self, datasets: list[str]):
        """
        datasets: iterable of paths to model-evaluation result files
        """
        self.datasets = datasets
        self.results = []

    def evaluate_datasets(self, plot=False):
        # iterate over the result files: load, build eye diagrams, count errors
        for dataset in self.datasets:
            model, dataset_name = dataset.split("/")[-2:]
            print(f"\nEvaluating model {model} with dataset {dataset_name}")
            data, metadata = model_runner.load_from_file(dataset)
            result = model_evaluation_result(**metadata)
            data = self.prepare_data(data, sps=metadata["sps"])
            try:
                sym_x, sym_y = metadata["orig_symbols"]
            except (TypeError, KeyError, ValueError):
                sym_x, sym_y = None, None
            # defer plt.show() until all figures are built
            self.evaluate_eye(data, result, title=dataset_name, plot=False)
            self.evaluate_ser_ber(data, result, sym_x, sym_y)
            print("BER:")
            self.print_dict(result.bers["regen"])
            print()
            print("SER:")
            self.print_dict(result.sers["regen"])
            print()
            self.results.append(result)
        if plot:
            plt.show()
    def evaluate_eye(self, data, result, title=None, plot=False):
        eye = util.eye_diagram.eye_diagram(
            data,
            channel_names=[
                "fiber_in_x",
                "fiber_in_y",
                "fiber_out_x",
                "fiber_out_y",
                "regen_x",
                "regen_y",
            ],
        )
        eye.analyse()
        eye.plot(title=title or "Eye diagram", show=plot)
        result.eye_stats = eye.eye_stats
        return eye.eye_stats
    def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None):
        if result.eye_stats is None:
            self.evaluate_eye(data, result)
        symbols = []
        sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
        bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
        for channel_data, stats in zip(data, result.eye_stats):
            timestamps = channel_data[0]
            dat = channel_data[1]
            channel_name = stats["channel_name"]
            if stats["success"]:
                thresholds = stats["thresholds"]
                time_midpoint = stats["time_midpoint"]
            else:
                # fall back to the matching fiber_in channel, or to evenly spaced levels
                if channel_name.endswith("x"):
                    thresholds = result.eye_stats[0]["thresholds"]
                    time_midpoint = result.eye_stats[0]["time_midpoint"]
                elif channel_name.endswith("y"):
                    thresholds = result.eye_stats[1]["thresholds"]
                    time_midpoint = result.eye_stats[1]["time_midpoint"]
                else:
                    levels = np.linspace(np.min(dat), np.max(dat), 4)
                    thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels)
                    time_midpoint = 1.0
            # sample one value per symbol, starting at the first sample
            nos = len(timestamps) // result.sps
            idx = list(range(0, len(timestamps), result.sps))
            idx = idx[:nos]
            data_sampled = dat[idx]
            detected_symbols = self.detect_symbols(data_sampled, thresholds)
            symbols.append({"channel_name": channel_name, "symbols": detected_symbols})
        # `x or y` is ambiguous for numpy arrays, so test for None explicitly
        symbols_x_gt = symbols[0]["symbols"] if sym_x is None else sym_x
        symbols_y_gt = symbols[1]["symbols"] if sym_y is None else sym_y
        symbols_x_fiber_out = symbols[2]["symbols"]
        symbols_y_fiber_out = symbols[3]["symbols"]
        symbols_x_regen = symbols[4]["symbols"]
        symbols_y_regen = symbols[5]["symbols"]
        sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out)
        sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out)
        sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen)
        sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen)
        result.sers = sers
        result.bers = bers
    @staticmethod
    def calculate_ser_ber(symbols_gt, symbols):
        # PAM4 symbol difference -> bit error count, assuming Gray coding
        # (0: 00, 1: 01, 2: 11, 3: 10):
        # |rx - tx| = 0 -> 0, 1 -> 1, 2 -> 2, 3 -> 1
        bec_map = {0: 0, 1: 1, 2: 2, 3: 1}
        ser = {}
        ber = {}
        ser["n_symbols"] = len(symbols_gt)
        ser["n_errors"] = int(np.sum(symbols != symbols_gt))
        ser["total"] = float(ser["n_errors"] / ser["n_symbols"])
        # default to 2 bit errors for any unmapped difference (NaN keys don't hash reliably)
        bec = np.vectorize(lambda d: bec_map.get(d, 2))(np.abs(symbols - symbols_gt))
        bit_errors = int(np.sum(bec))
        ber["n_bits"] = len(symbols_gt) * 2
        ber["n_errors"] = bit_errors
        ber["total"] = float(ber["n_errors"] / ber["n_bits"])
        return ser, ber
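
    # Worked example: tx = [0, 1, 2, 3], rx = [0, 2, 2, 0] gives |rx - tx| = [0, 1, 0, 3]
    # -> bit errors [0, 1, 0, 1], so SER = 2/4 = 0.5 and BER = 2/8 = 0.25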
    @staticmethod
    def print_dict(d: dict, indent=2, logarithmic=False, level=0):
        for key, value in d.items():
            if isinstance(value, dict):
                print(f"{' ' * indent * level}{key}:")
                evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1)
            else:
                if isinstance(value, float):
                    if logarithmic:
                        if value == 0:
                            value = -np.inf
                        else:
                            value = np.log10(value)
                    print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="")
                else:
                    print(f"{' ' * indent * level}{key}: {value}\t", end="")
        print()
    @staticmethod
    def detect_symbols(samples, thresholds=None):
        thresholds = (1 / 6, 3 / 6, 5 / 6) if thresholds is None else thresholds
        thresholds = (-np.inf, *thresholds, np.inf)
        symbols = np.digitize(samples, thresholds) - 1
        return symbols
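
    # e.g. detect_symbols([0.05, 0.4, 0.6, 0.95]) -> array([0, 1, 2, 3])
    # with the default PAM4 thresholds (1/6, 3/6, 5/6)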
    @staticmethod
    def prepare_data(data, sps=None):
        # columns from run_model: fiber_in_x/y, fiber_out_x/y, regen_x/y, timestamp
        data = data.transpose(1, 0)
        timestamps = data[-1].real
        data = data[:-1]
        if sps is not None:
            timestamps /= sps
        data_eye = []
        for channel_values in data:
            # eye diagrams are drawn on optical power, i.e. |E|^2
            channel_values = np.square(np.abs(channel_values))
            data_eye.append(np.stack((timestamps, channel_values), axis=0))
        data_eye = np.stack(data_eye, axis=0)
        return data_eye
def generate_data(parameters, runner=None):
    runner = runner or model_runner()
    for param in parameters:
        runner.update_parameter_range(*param)
    runner.generate_iterations()
    print(f"{runner.iter.length} parameter combinations")
    runner.generate_datasets()
    return runner
if __name__ == "__main__":
    model_path = ".models/best_20250110_191149.tar"  # D 17, OSNR 100, delta_beta 0.14, baud 10e9
    parameters = (
        # name, config keys, (min, max), n_steps, log
        # ("D", ("fiber", "d"), (28, 30), 3, False),
        # ("S", ("fiber", "s"), (0, 0.058), 2, False),
        ("OSNR", ("signal", "osnr"), (20, 40), 5, False),
        # ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False),
        # ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True),
    )
    datasets = None
    results = None
    # datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5"
    results = "tolerance_results/datasets/results_list_20250110_232639.h5"
    runner = model_runner()
    if results is None:
        if datasets is None:
            generate_data(parameters, runner)
        else:
            runner.load_datasets_from_file(datasets)
            print(f"{len(runner.datasets)} datasets loaded")
        runner.run_model_evaluation(model_path)
    else:
        runner.load_results_from_file(results)
    evaluation = evaluator(runner.results)
    evaluation.evaluate_datasets(plot=True)

View File

@@ -482,6 +482,15 @@ class Identity(nn.Module):
     def forward(self, x):
         return x
 
+
+class phase_shift(nn.Module):
+    def __init__(self, size):
+        super(phase_shift, self).__init__()
+        self.size = size
+        self.phase = nn.Parameter(torch.rand(size))
+
+    def forward(self, x):
+        return x * torch.exp(1j*self.phase)
 
 class PowRot(nn.Module):
     def __init__(self, bias=False):
@@ -531,19 +540,19 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
 class EOActivation(nn.Module):
     def __init__(self, size=None):
-        # 10.1109/SiPhotonics60897.2024.10543376
+        # 10.1109/JSTQE.2019.2930455
         super(EOActivation, self).__init__()
         if size is None:
             raise ValueError("Size must be specified")
         self.size = size
-        self.alpha = nn.Parameter(torch.ones(size))
-        self.V_bias = nn.Parameter(torch.ones(size))
-        self.gain = nn.Parameter(torch.ones(size))
+        self.alpha = nn.Parameter(torch.rand(size))
+        self.V_bias = nn.Parameter(torch.rand(size))
+        self.gain = nn.Parameter(torch.rand(size))
         # if bias:
         #     self.phase_bias = nn.Parameter(torch.zeros(size))
         # else:
         #     self.register_buffer("phase_bias", torch.zeros(size))
-        self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
+        # self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
         self.register_buffer("responsivity", torch.ones(size)*0.9)
         self.register_buffer("V_pi", torch.ones(size)*3)
@@ -551,17 +560,17 @@ class EOActivation(nn.Module):
     def reset_weights(self):
         if "alpha" in self._parameters:
-            self.alpha.data = torch.ones(self.size)*0.5
+            self.alpha.data = torch.rand(self.size)
         if "V_pi" in self._parameters:
-            self.V_pi.data = torch.ones(self.size)*3
+            self.V_pi.data = torch.rand(self.size)*3
         if "V_bias" in self._parameters:
-            self.V_bias.data = torch.zeros(self.size)
+            self.V_bias.data = torch.randn(self.size)
         if "gain" in self._parameters:
-            self.gain.data = torch.ones(self.size)
+            self.gain.data = torch.rand(self.size)
         if "responsivity" in self._parameters:
             self.responsivity.data = torch.ones(self.size)*0.9
-        if "bias" in self._parameters:
-            self.phase_bias.data = torch.zeros(self.size)
+        # if "bias" in self._parameters:
+        #     self.phase_bias.data = torch.zeros(self.size)
 
     def forward(self, x: torch.Tensor):
         phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
@@ -570,12 +579,11 @@ class EOActivation(nn.Module):
         return (
             1j
             * torch.sqrt(1 - self.alpha)
-            * torch.exp(-0.5j * (intermediate + self.phase_bias))
+            * torch.exp(-0.5j * intermediate)
             * torch.cos(0.5 * intermediate)
             * x
         )
 
 
 class Pow(nn.Module):
     """
     implements the activation function
@@ -716,6 +724,7 @@ __all__ = [
     MZISingle,
     EOActivation,
     photodiode,
+    phase_shift,
     # SaturableAbsorberLambertW,
     # SaturableAbsorber,
     # SpreadLayer,

View File

@@ -1,4 +1,5 @@
 from pathlib import Path
 
+import h5py
 import torch
 from torch.utils.data import Dataset
@@ -24,8 +25,22 @@ import multiprocessing as mp
 #     def __len__(self):
 #         return len(self.indices)
 
+
+def load_from_file(datapath):
+    if str(datapath).endswith('.h5'):
+        symbols = None
+        with h5py.File(datapath, "r") as infile:
+            data = infile["data"][:]
+            try:
+                symbols = infile["symbols"][:]
+            except KeyError:
+                pass
+    else:
+        symbols = None
+        data = np.load(datapath)
+    return data, symbols
+
 
-def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
+def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
     filepath = Path(config_path)
     filepath = filepath.parent.glob(filepath.name)
     config = configparser.ConfigParser()
@@ -41,14 +56,20 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
     if symbols is None:
         symbols = int(config["glova"]["nos"]) - skipfirst
 
-    data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
+    data, orig_symbols = load_from_file(datapath)
+    data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
+    orig_symbols = orig_symbols[skipfirst:symbols+skipfirst]
     timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
 
-    if normalize:
-        # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
-        a, b, c, d = np.square(data.T)
-        a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
-        data = np.sqrt(np.array([a, b, c, d]).T)
+    data *= np.sqrt(normalize)
+
+    # if normalize:
+    #     # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
+    #     a, b, c, d = data.T
+    #     a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
+    #     a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
+    #     data = np.array([a, b, c, d]).T
 
     if real:
         data = np.abs(data)
@@ -59,7 +80,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
     data = torch.tensor(data, device=device, dtype=dtype)
 
-    return data, config
+    return data, config, orig_symbols
 
 
 def roll_along(arr, shifts, dim):
@@ -114,7 +135,8 @@ class FiberRegenerationDataset(Dataset):
         dtype: torch.dtype = None,
         real: bool = False,
         device=None,
-        osnr: float = None,
+        # osnr: float|None = None,
+        polarisations = None,
         randomise_polarisations: bool = False,
         repeat_randoms: int = 1,
         **kwargs,
@@ -151,36 +173,26 @@ class FiberRegenerationDataset(Dataset):
         self.randomise_polarisations = randomise_polarisations
 
-        faux = kwargs.pop("faux", False)
-        if faux:
-            data_raw = np.array(
-                [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
-                dtype=np.complex128,
-            )
-            data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
-            timestamps = torch.arange(12800)
-            data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
-            self.config = {
-                "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
-                "glova": {"sps": 128},
-            }
-        else:
-            data_raw = None
-            self.config = None
+        data_raw = None
+        self.config = None
         files = []
+        self.orig_symbols = None
         for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
-            data, config = load_data(
+            data, config, orig_syms = load_data(
                 file_path,
                 skipfirst=drop_first,
                 symbols=kwargs.get("num_symbols", None),
                 real=real,
-                normalize=True,
+                normalize=1000,
                 device=device,
                 dtype=dtype,
             )
+            if orig_syms is not None:
+                if self.orig_symbols is None:
+                    self.orig_symbols = orig_syms
+                else:
+                    self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)
             if data_raw is None:
                 data_raw = data
             else:
@@ -193,23 +205,18 @@ class FiberRegenerationDataset(Dataset):
         self.config["data"]["file"] = str(files)
 
         # if polarisations is not None:
-        #     self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1)
-        #     for i, angle in enumerate(torch.tensor(np.array(polarisations))):
-        #         data_raw_copy = data_raw.clone()
-        #         if angle == 0:
-        #             continue
-        #         sine = torch.sin(angle)
-        #         cosine = torch.cos(angle)
-        #         data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
-        #         data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
-        #         if i == 0:
-        #             data_raw = data_raw_copy
-        #         else:
-        #             data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
+        #     data_raw_clone = data_raw.clone()
+        #     # rotate the polarisation by 180 degrees
+        #     data_raw_clone[2, :] *= -1
+        #     data_raw_clone[3, :] *= -1
+        #     data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
+
+        self.polarisations = bool(polarisations)
 
         self.device = data_raw.device
         self.samples_per_symbol = int(self.config["glova"]["sps"])
         # self.num_symbols = int(self.config["glova"]["nos"])
 
         self.samples_per_slice = int(symbols * self.samples_per_symbol)
         self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -290,6 +297,34 @@
         fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
         fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
 
+        # fiber_out: [E_out_x, E_out_y, timestamps]
+
+        # add noise related to amplification necessary due to splitting of the signal
+        gain_lin = output_dim*2
+        edfa_nf = float(self.config["signal"]["edfa_nf"])
+        nf_lin = 10**(edfa_nf/10)
+        f0 = float(self.config["glova"]["f0"])
+        noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
+        noise = torch.randn_like(fiber_out[:2, :])
+        noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
+        noise = noise * torch.sqrt(noise_add / noise_power)
+        fiber_out[:2, :] += noise
+
+        # if osnr is None:
+        #     noisy = fiber_out[:2, :]
+        # else:
+        #     noisy = self.add_noise(fiber_out[:2, :], osnr)
+        # fiber_out = torch.cat([fiber_out, noisy], dim=0)
+        # fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
+
         if repeat_randoms > 1:
             fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
             fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
@@ -298,27 +333,33 @@
             repeat_randoms = 1
 
         if self.randomise_polarisations:
-            angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi
+            angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
             # start_angle = torch.rand(1) * 2 * torch.pi
             # angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2  # random walk
             # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
         else:
-            angles = torch.zeros(data_raw.shape[-1])
+            angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
         sin = torch.sin(angles)
         cos = torch.cos(angles)
         rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
         data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
+        # data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
 
         fiber_out = torch.cat((fiber_out, data_rot), dim=0)
         fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
 
-        if osnr is not None:
-            popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1)
-            noise = torch.randn_like(fiber_out[:2, :, :])
-            pn = torch.mean(noise.abs().flatten(), dim=-1)
-            noise = noise * (popt / pn) * 10 ** (-osnr / 20)
-            fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise)
+        # fiber_in:
+        # 0 E_in_x,
+        # 1 E_in_y,
+        # 2 timestamps
+
+        # fiber_out:
+        # 0 E_out_x,
+        # 1 E_out_y,
+        # 2 timestamps,
+        # 3 E_out_x_rot,
+        # 4 E_out_y_rot,
+        # 5 angle
@@ -350,6 +391,22 @@
     def __len__(self):
         return self.fiber_in.shape[0]
 
+    def add_noise(self, data, osnr):
+        osnr_lin = 10**(osnr/10)
+        popt = torch.mean(data.abs().square().squeeze(), dim=-1)
+        noise = torch.randn_like(data)
+        pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
+        mult = torch.sqrt(popt/(pn*osnr_lin))
+        mult = mult * torch.eye(popt.shape[0], device=mult.device)
+        mult = mult.to(dtype=noise.dtype)
+        noise = mult @ noise
+        pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
+        noisy = data + noise
+        return noisy
+
     def __getitem__(self, idx):
         if isinstance(idx, slice):
             return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
@@ -357,14 +414,19 @@
         # fiber in: [E_in_x, E_in_y, timestamps]
         # fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
 
+        # if self.polarisations:
+        output_dim = self.output_dim // 2
+        self.output_dim = output_dim * 2
+
         fiber_in = self.fiber_in[idx].squeeze()
         fiber_out = self.fiber_out[idx].squeeze()
 
-        fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim]
-        fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim]
+        fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
+        fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
 
-        fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1)
-        fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1)
+        fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
+        fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
 
         # data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
@@ -372,11 +434,36 @@
         # angle = self.angles[idx]
 
-        center_angle = fiber_out[5, self.output_dim // 2, 0]
+        # fiber_in:
+        # 0 E_in_x,
+        # 1 E_in_y,
+        # 2 timestamps
+
+        # fiber_out:
+        # 0 E_out_x,
+        # 1 E_out_y,
+        # 2 timestamps,
+        # 3 E_out_x_rot,
+        # 4 E_out_y_rot,
+        # 5 angle
+
+        center_angle = fiber_out[0, output_dim // 2, 0]
         angles = fiber_out[5, :, 0]
-        plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone()
-        plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone()
-        data = fiber_out[3:5, :, 0]
+        plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
+        plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
+        data = fiber_out[0:2, :, 0]
+
+        # fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone()
+
+        # if self.polarisations:
+        #     rot = int(np.random.randint(2)*2-1)
+        #     pol_flipped_data[0:1, :] = rot*data[0, :]
+        #     pol_flipped_data[1, :] = rot*data[1, :]
+        #     plot_data_rot[0] = rot*plot_data_rot[0]
+        #     plot_data_rot[1] = rot*plot_data_rot[1]
+        #     center_angle = center_angle + (rot - 1) * torch.pi/2  # rot: -1 or 1 -> -2 or 0 -> -pi or 0
+        #     angles = angles + (rot - 1) * torch.pi/2
 
         # if self.randomise_polarisations:
         #     data = data.mT
@@ -389,16 +476,27 @@
         #     angle = torch.zeros_like(angle)
 
         # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
-        angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim)
-        angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim)
+        # angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim)
+        # angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim)
 
         # sop = self.polarimeter(plot_data)
         # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
         # angle = data_slice[1, 3, self.output_dim // 2, 0].real
 
-        target = fiber_in[:2, self.output_dim // 2, 0]
-        plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone()
-        target_timestamp = fiber_in[2, self.output_dim // 2, 0].real
+        target = fiber_in[:2, output_dim // 2, 0]
+        plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
+        target_timestamp = fiber_in[2, output_dim // 2, 0].real
 
         ...
 
+        if self.polarisations:
+            rot = int(np.random.randint(2)*2-1)
+            data = rot*data
+            target = rot*target
+            plot_data_rot = rot*plot_data_rot
+            center_angle = center_angle + (rot - 1) * torch.pi/2  # rot: -1 or 1 -> -2 or 0 -> -pi or 0
+            angles = angles + (rot - 1) * torch.pi/2
+
+        pol_flipped_data = -data
+        pol_flipped_target = -target
+
         # data_timestamps = data[-1,:].real
         # data = data[:-1, :]
         # target_timestamp = target[-1].real
@@ -407,13 +505,15 @@
         # transpose to interleave the x and y data in the output tensor
         data = data.transpose(0, 1).flatten().squeeze()
-        angle_data = angle_data.transpose(0, 1).flatten().squeeze()
-        angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
+        pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
+        # angle_data = angle_data.transpose(0, 1).flatten().squeeze()
+        # angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
         center_angle = center_angle.flatten().squeeze()
         angles = angles.flatten().squeeze()
         # data_timestamps = data_timestamps.flatten().squeeze()
         # target = target.transpose(0,1).flatten().squeeze()
         target = target.flatten().squeeze()
+        pol_flipped_target = pol_flipped_target.flatten().squeeze()
        target_timestamp = target_timestamp.flatten().squeeze()
         plot_target = plot_target.flatten().squeeze()
         plot_data = plot_data.flatten().squeeze()
@@ -421,17 +521,22 @@
         return {
             "x": data,
+            "x_flipped": pol_flipped_data,
+            "x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
             "y": target,
-            "center_angle": center_angle,
-            "angles": angles,
+            "y_flipped": pol_flipped_target,
+            "y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
+            # "center_angle": center_angle,
+            # "angles": angles,
             "mean_angle": angles.mean(),
             # "sop": sop,
-            "angle_data": angle_data,
-            "angle_data2": angle_data2,
+            # "angle_data": angle_data,
+            # "angle_data2": angle_data2,
             "timestamp": target_timestamp,
             "plot_target": plot_target,
             "plot_data": plot_data,
             "plot_data_rot": plot_data_rot,
+            # "plot_clean": fiber_out_plot_clean,
         }
 
     def complex_max(self, data, dim=-1):

View File

@@ -82,7 +82,6 @@ class eye_diagram:
         self.vertical_bins = vertical_bins
         self.multi_threaded = multithreaded
         self.eye_built = False
-        self.analyse()
 
     def generate_eye_data(self):
         self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
@@ -126,6 +125,7 @@
         rows = int(np.ceil(self.channels / cols))
         fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
         fig.suptitle(title)
+        fig.tight_layout()
         ax = np.atleast_1d(ax).transpose().flatten()
         for i in range(self.channels):
             ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
@@ -147,19 +147,21 @@
             yspan = ymax - ymin
             ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
             if stats and self.eye_stats[i]["success"]:
-                # add min_area above the plot
-                ax[i].annotate(
-                    f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
-                    xy=(0.05, ymax + 0.05 * yspan),
-                    # xycoords="axes fraction",
-                    ha="left",
-                    va="center",
-                    bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
-                )
+                # # add min_area above the plot
+                # ax[i].annotate(
+                #     f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
+                #     xy=(0.05, ymax + 0.05 * yspan),
+                #     # xycoords="axes fraction",
+                #     ha="left",
+                #     va="center",
+                #     bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                # )
                 if all_stats:
                     ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
-                    ax[i].set_yticks(self.eye_stats[i]["levels"])
+                    y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
+                    # y_ticks = np.sort(y_ticks)
+                    ax[i].set_yticks(y_ticks)
                     # add arrows for amplitudes
                     for j in range(len(self.eye_stats[i]["amplitudes"])):
                         ax[i].annotate(
@@ -193,35 +195,35 @@ class eye_diagram:
                     except (ValueError, IndexError):
                         pass
-                # add arrows for eye widths
-                for j in range(len(self.eye_stats[i]["widths"])):
-                    try:
-                        left = np.max(self.eye_stats[i]["time_clusters"][j][0])
-                        right = np.min(self.eye_stats[i]["time_clusters"][j][1])
-                        vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
-                        ax[i].annotate(
-                            "",
-                            xy=(left, vertical),
-                            xytext=(right, vertical),
-                            arrowprops=dict(arrowstyle="<->", facecolor="black"),
-                        )
-                        ax[i].annotate(
-                            f"{self.eye_stats[i]['widths'][j]:.2e}",
-                            xy=((left + right) / 2 - 0.15, vertical + 0.01),
-                            bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
-                        )
-                    except (ValueError, IndexError):
-                        pass
-                # add area
-                for j in range(len(self.eye_stats[i]["areas"])):
-                    horizontal = self.eye_stats[i]["time_midpoint"]
-                    vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
-                    ax[i].annotate(
-                        f"{self.eye_stats[i]['areas'][j]:.2e}",
-                        xy=(horizontal + 0.035, vertical - 0.07),
-                        bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
-                    )
+                # add arrows for eye widths
+                # for j in range(len(self.eye_stats[i]["widths"])):
+                #     try:
+                #         left = np.max(self.eye_stats[i]["time_clusters"][j][0])
+                #         right = np.min(self.eye_stats[i]["time_clusters"][j][1])
+                #         vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+                #         ax[i].annotate(
+                #             "",
+                #             xy=(left, vertical),
+                #             xytext=(right, vertical),
+                #             arrowprops=dict(arrowstyle="<->", facecolor="black"),
+                #         )
+                #         ax[i].annotate(
+                #             f"{self.eye_stats[i]['widths'][j]:.2e}",
+                #             xy=((left + right) / 2 - 0.15, vertical + 0.01),
+                #             bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                #         )
+                #     except (ValueError, IndexError):
+                #         pass
+                # # add area
+                # for j in range(len(self.eye_stats[i]["areas"])):
+                #     horizontal = self.eye_stats[i]["time_midpoint"]
+                #     vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+                #     ax[i].annotate(
+                #         f"{self.eye_stats[i]['areas'][j]:.2e}",
+                #         xy=(horizontal + 0.035, vertical - 0.07),
+                #         bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                #     )
 
         fig.tight_layout()
@@ -229,6 +231,12 @@
             plt.show()
         return fig
 
+    @staticmethod
+    def calculate_thresholds(levels):
+        ret = np.cumsum(levels, dtype=float)
+        ret[2:] = ret[2:] - ret[:-2]
+        return ret[1:]/2
+
     def analyse_single(self, data, index):
         warnings.filterwarnings("error")
         eye_stats = {}
@@ -238,12 +246,15 @@
             time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
 
-            eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
+            eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2
+            eye_stats["time_midpoint"] = 1.0
 
             eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
                 data, approx_levels, time_bounds
             )
+            eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
             eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
             eye_stats["heights"] = eye_diagram.calculate_eye_heights(
@@ -260,22 +271,23 @@
             #     if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
             #         raise ValueError
 
-            eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
-            eye_stats["mean_area"] = np.mean(eye_stats["areas"])
-            eye_stats["min_area"] = np.min(eye_stats["areas"])
+            # eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
+            # eye_stats["mean_area"] = np.mean(eye_stats["areas"])
+            # eye_stats["min_area"] = np.min(eye_stats["areas"])
 
             eye_stats["success"] = True
         except (RuntimeWarning, UserWarning, ValueError):
             eye_stats["success"] = False
-            eye_stats["time_midpoint"] = 0
-            eye_stats["levels"] = np.zeros(self.n_levels)
-            eye_stats["amplitude_clusters"] = []
-            eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
-            eye_stats["heights"] = np.zeros(self.n_levels - 1)
-            eye_stats["widths"] = np.zeros(self.n_levels - 1)
-            eye_stats["areas"] = np.zeros(self.n_levels - 1)
-            eye_stats["mean_area"] = 0
-            eye_stats["min_area"] = 0
+            eye_stats["time_midpoint"] = None
+            eye_stats["levels"] = None
+            eye_stats["thresholds"] = None
+            eye_stats["amplitude_clusters"] = None
+            eye_stats["amplitudes"] = None
+            eye_stats["heights"] = None
+            eye_stats["widths"] = None
+            # eye_stats["areas"] = np.zeros(self.n_levels - 1)
+            # eye_stats["mean_area"] = 0
+            # eye_stats["min_area"] = 0
 
         warnings.resetwarnings()
         return eye_stats
@@ -441,7 +453,8 @@ if __name__ == "__main__":
     data = generate_sample_data(length, noise=0.005)
     eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
-    attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
+    eye.analyse()
+    attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")  # , "area", "mean_area", "min_area")
     for i, channel in enumerate(eye.eye_stats):
         print(f"Channel {i}")
         print_data = {attr: channel[attr] for attr in attrs}

View File

@@ -1,6 +1,9 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from .datasets import load_data
+if __name__ == "__main__":
+    from datasets import load_data
+else:
+    from .datasets import load_data
 
 
 def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
     """Plot an eye diagram for the data given by filepath.
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
         raise ValueError("Either path or data and sps must be given.")
     if path is not None:
         data, config = load_data(path, skipfirst, symbols)
+        data = data.detach().cpu().numpy()[:, :4]
         sps = int(config["glova"]["sps"])
     if sps is None:
         raise ValueError("sps not set.")
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
         plt.show()
 
     return fig
+
+if __name__ == "__main__":
+    eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)