model robustness testing

This commit is contained in:
Joseph Hopfmüller
2025-01-10 23:40:54 +01:00
parent 3af73343c1
commit f38d0ca3bb
13 changed files with 1558 additions and 334 deletions

1
.gitignore vendored
View File

@@ -163,3 +163,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
tolerance_results/datasets/*

37
notes/models.md Normal file
View File

@@ -0,0 +1,37 @@
# models
## no polarisation flipping
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241230_011907.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241230_103752.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241230_164534.tar"
```
## with polarisation flipping
polarisation flipping: the signal is randomly rotated by 180°. A polarisation rotation can be detected by adding a tone on one of the polarisations, but with a direct-detection setup only up to mod 180°. The randomly flipped signal should let the network learn to compensate for dispersion and PMD independently of the polarisation rotation. The training data includes the flipped signal as well, but carries no indication of whether the polarisation is flipped; a minimal sketch of the augmentation follows the model list below.
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241231_000328.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241231_163614.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241231_170532.tar"
```
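a minimal sketch of the flip augmentation (editor's illustration with a hypothetical helper, not the actual training code): a 180° rotation of the Jones vector is a sign flip of both field components, applied per sample with no flag recording whether it happened.
```py
import torch

def random_pol_flip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Rotate each sample's Jones vector by 180° with probability p.

    A 180° polarisation rotation is R(pi) = -I, i.e. a sign flip of both
    field components (Ex, Ey); dim 0 is assumed to be the batch dimension.
    """
    flip = torch.rand(x.shape[0], device=x.device) < p
    sign = torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
    sign[flip] = -1
    # broadcast the per-sample sign over all remaining dimensions
    return x * sign.view(-1, *([1] * (x.ndim - 1)))
```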

View File

@@ -124,7 +124,7 @@ class regenerator(Module):
parametrizations: list[dict] = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
prescale=1,
rotate=False,
):
super(regenerator, self).__init__()
@@ -134,15 +134,14 @@ class regenerator(Module):
act_func_kwargs = act_func_kwargs or {}
self.rotation = rotate
self.prescale = prescale
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
@@ -156,8 +155,8 @@ class regenerator(Module):
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
# if scale_layers:
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
@@ -200,6 +199,7 @@ class regenerator(Module):
return powers
def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
x = x * self.prescale
powers = self._trace_powers(trace_powers, x)
# x = self.layer_0(x)
# powers = self._trace_powers(trace_powers, x, powers)

View File

@@ -683,7 +683,7 @@ class RegenerationTrainer:
def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
model_kwargs = None
else:
model_kwargs = model_kwargs
@@ -692,6 +692,14 @@ class RegenerationTrainer:
input_dim = 2 * self.data_settings.output_size
# if self.data_settings.polarisations:
# input_dim *= 2
output_dim = self.model_settings.output_dim
# if self.data_settings.polarisations:
output_dim *= 2
dtype = getattr(torch, self.data_settings.dtype)
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
@@ -703,7 +711,7 @@ class RegenerationTrainer:
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"dims": (input_dim, *hidden_dims, output_dim),
"layer_function": layer_func,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
@@ -711,7 +719,7 @@ class RegenerationTrainer:
"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": self.model_settings.scale,
"prescale": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
@@ -745,11 +753,12 @@ class RegenerationTrainer:
num_symbols = None
config_path = self.data_settings.config_path
randomise_polarisations = self.data_settings.randomise_polarisations
polarisations = self.data_settings.polarisations
osnr = self.data_settings.osnr
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
# polarisations = override.get("polarisations", polarisations)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# get dataset
dataset = FiberRegenerationDataset(
@@ -763,6 +772,7 @@ class RegenerationTrainer:
real=not dtype.is_complex,
num_symbols=num_symbols,
randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
osnr = osnr,
)
@@ -832,17 +842,19 @@ class RegenerationTrainer:
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
angles = batch["mean_angle"]
x = batch[x_key]
y = batch[y_key]
angle = batch["mean_angle"]
self.model.zero_grad(set_to_none=True)
x, y, angles = (
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angles.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x, -angles)
y_pred = self.model(x, -angle)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
@@ -886,17 +898,19 @@ class RegenerationTrainer:
self.model.eval()
running_error = 0
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
angles = batch["mean_angle"]
x, y, angles = (
x = batch[x_key]
y = batch[y_key]
angle = batch["mean_angle"]
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angles.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x, -angles)
y_pred = self.model(x, -angle)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
@@ -953,15 +967,17 @@ class RegenerationTrainer:
regen = []
timestamps = []
angles = []
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for batch in loader:
x = batch["x"]
y = batch["y"]
x = batch[x_key]
y = batch[y_key]
plot_target = batch["plot_target"]
angle = batch["mean_angle"]
center_angle = batch["center_angle"]
# center_angle = batch["center_angle"]
timestamp = batch["timestamp"]
plot_data = batch["plot_data"]
plot_data_rot = batch["plot_data_rot"]
@@ -971,14 +987,16 @@ class RegenerationTrainer:
angle.to(self.pytorch_settings.device),
)
if trace_powers:
y_pred, powers = model(x, angle, True).cpu()
y_pred, powers = model(x, -angle, True).cpu()
else:
y_pred = model(x, angle).cpu()
y_pred = model(x, -angle).cpu()
# x = x.cpu()
# y = y.cpu()
# if self.data_settings.polarisations:
y_pred = y_pred[:, :2]
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y_pred = y_pred[:, y_pred.shape[1]//2, :]
y = y.view(y.shape[0], -1, 2)
# y = y.view(y.shape[0], -1, 2)
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# c = torch.cos(-angle).cpu()
# s = torch.sin(-angle).cpu()
@@ -996,7 +1014,7 @@ class RegenerationTrainer:
fiber_in.append(plot_target.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
angles.append(center_angle.squeeze())
angles.append(angle.squeeze())
fiber_out = torch.vstack(fiber_out).cpu()
fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
@@ -1352,7 +1370,8 @@ class RegenerationTrainer:
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
"shuffle": False,
"polarisations": (np.random.rand(1) * np.pi * 2,),
# "polarisations": (np.random.rand(1) * np.pi * 2,),
"polarisations": self.data_settings.polarisations,
"randomise_polarisation": self.data_settings.randomise_polarisations,
}
)
@@ -1366,7 +1385,7 @@ class RegenerationTrainer:
fiber_out_rot = fiber_out_rot.view(-1, 2)
angles = angles.view(-1, 1)
angles = angles.real
angles = torch.fmod(angles, 2 * torch.pi)
angles = torch.fmod(angles, 2*torch.pi)
angles = torch.div(angles, 2*torch.pi)
angles = torch.repeat_interleave(angles, 2, dim=1)

View File

@@ -0,0 +1,253 @@
import os
from matplotlib import pyplot as plt
import numpy as np
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models
# def move_to_location_in_size(array, location, size):
# array_x, array_y = array.shape
# location_x, location_y = location
# size_x, size_y = size
# left = location_x
# right = size_x - array_x - location_x
# top = location_y
# bottom = size_y - array_y - location_y
# return np.pad(
# array,
# (
# (left, right),
# (top, bottom),
# ),
# constant_values=(-np.inf, -np.inf),
# )
def pad_to_size(array, size):
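# Centres `array` on a NaN-padded canvas of the given size; a dimension
# given as None keeps its original extent (zero padding on that axis).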
if not hasattr(size, "__len__"):
size = (size, size)
left = (
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
)
right = (
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
)
top = (
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
)
bottom = (
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
)
array: np.ndarray = array
if array.ndim == 2:
return np.pad(
array,
(
(left, right),
(top, bottom),
),
constant_values=(np.nan, np.nan),
)
elif array.ndim == 3:
return np.pad(
array,
(
(left, right),
(top, bottom),
(0,0)
),
constant_values=(np.nan, np.nan),
)
def model_plot(model_path):
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
checkpoint_dict = torch.load(model_path, weights_only=True)
dims = checkpoint_dict["model_kwargs"].pop("dims")
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
model.load_state_dict(checkpoint_dict["model_state_dict"])
model_params = []
plots = []
max_size = np.max(dims)
# max_act_size = np.max(dims[1:])
angles = [None, None]
weights = [None, None]
for num, (layer_name, layer) in enumerate(model.named_children()):
# each layer contains an "ONN" layer and an "activation" layer
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
onn_weights = layer.ONN.weight.T
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
act = layer.activation
act_values = np.ones((act.size, 1))
act_values = np.nan * act_values
act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy()
...
# act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8)
# act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8)
# xs = (0.01, 0.1, 1)
# act_values = np.zeros((act.size, len(xs)*2))
# act_angles = np.zeros((act.size, len(xs)*2))
# act_values[:,:] = np.nan
# act_angles[:,:] = np.nan
# for xi, x in enumerate(xs):
# phi_intermediate = act_phi_gain * x**2 + act_phi_bias
# act_resulting_gain = (
# 1j
# * torch.sqrt(1-act.alpha)
# * torch.exp(-0.5j * phi_intermediate)
# * torch.cos(0.5 * phi_intermediate)
# * x
# )
# act_resulting_gain = act_resulting_gain.detach().cpu().numpy()
# act_values[:, xi*2] = np.abs(act_resulting_gain).real
# act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real
# if angles[0] is None or angles[0] > np.min(onn_angles.flatten()):
# angles[0] = np.min(onn_angles.flatten())
# if angles[1] is None or angles[1] < np.max(onn_angles.flatten()):
# angles[1] = np.max(onn_angles.flatten())
# if weights[0] is None or weights[0] > np.min(onn_weights.flatten()):
# weights[0] = np.min(onn_weights.flatten())
# if weights[1] is None or weights[1] < np.max(onn_weights.flatten()):
# weights[1] = np.max(onn_weights.flatten())
model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)})
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
for plot in plots:
layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem()
# for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0])
# for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0])
onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values))
onn_values = onn_values - np.min(onn_values)
onn_values = onn_values / np.max(onn_values)
act_values = np.ma.array(act_values, mask=np.isnan(act_values))
act_values = act_values - np.min(act_values)
act_values = act_values / np.max(act_values)
onn_values = pad_to_size(onn_values, (max_size, None))
act_values = pad_to_size(act_values, (max_size, 3))
onn_angles = onn_angles / np.pi
onn_angles = pad_to_size(onn_angles, (max_size, None))
act_angles = act_angles / np.pi
act_angles = pad_to_size(act_angles, (max_size, 3))
# onn_angles = onn_angles - np.min(onn_angles)
# onn_angles = onn_angles / np.max(onn_angles)
# act_angles = act_angles - np.min(act_angles)
# act_angles = act_angles / np.max(act_angles)
if num == 0:
value_img = np.concatenate((onn_values, act_values), axis=1)
angle_img = np.concatenate((onn_angles, act_angles), axis=1)
else:
value_img = np.concatenate((value_img, onn_values, act_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1)
# -np.inf to np.nan
# value_img[value_img == -np.inf] = np.nan
# angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size)
# angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size)
from cmcrameri import cm
from matplotlib import colors as mcolors
alpha_map = mcolors.LinearSegmentedColormap(
'alphamap',
{
'red': [(0, 0, 0), (1, 0, 0)],
'green': [(0, 0, 0), (1, 0, 0)],
'blue': [(0, 0, 0), (1, 0, 0)],
'alpha': [
(0, 1, 1),
# (0.2, 0.2, 0.1),
(1, 0, 0)
]
}
)
alpha_map.set_bad(color="#AAAAAA")
fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5))
fig.tight_layout()
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
masked_value_img = value_img
cmap = cm.batlowW
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap)
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
cmap = cm.romaO
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap)
im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
axs[0].axis("off")
axs[1].axis("off")
axs[2].axis("off")
axs[0].set_title("Values")
axs[1].set_title("Angles")
axs[2].set_title("Values and Angles")
...
plt.show()
# model = models.regenerator(*dims, **model_kwargs)
if __name__ == "__main__":
model_plot(".models/best_20250105_145719.tar")

View File

@@ -1,4 +1,4 @@
from datetime import datetime
# from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
@@ -13,7 +13,7 @@ from hypertraining.settings import (
OptimizerSettings,
)
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
from hypertraining.training import RegenerationTrainer#, PolarizationTrainer
# import torch
import json
@@ -27,7 +27,7 @@ global_settings = GlobalSettings(
data_settings = DataSettings(
# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",
config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -37,12 +37,13 @@ data_settings = DataSettings(
shuffle=True,
drop_first=64,
train_split=0.8,
randomise_polarisations=True,
osnr=10,
randomise_polarisations=False,
polarisations=True,
osnr=16, #16dB due to amplification with NF 5
)
pytorch_settings = PytorchSettings(
epochs=10000,
epochs=1000,
batchsize=2**14,
device="cuda",
dataloader_workers=24,
@@ -64,11 +65,11 @@ model_settings = ModelSettings(
# "n_hidden_nodes_3": 4,
# "n_hidden_nodes_4": 2,
},
model_activation_func="EOActivation",
model_activation_func="phase_shift",
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=False,
scale=2.0,
model_layer_parametrizations=[
{
"tensor_name": "weight",
@@ -77,13 +78,17 @@ model_settings = ModelSettings(
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 1,
},
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
"max": None,
},
},
{
@@ -95,8 +100,12 @@ model_settings = ModelSettings(
},
},
{
"tensor_name": "scales",
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2,
},
},
{
"tensor_name": "angle",
@@ -244,9 +253,17 @@ if __name__ == "__main__":
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
checkpoint_path=".models/best_20241216_221359.tar",
# checkpoint_path=".models/best_20250104_191428.tar",
reset_epoch=True,
# settings_override={
# "data_settings": {
# "config_path": "data/20241229-163*-128-16384-100000-*.ini",
# "polarisations": True,
# },
# "model_settings": {
# "scale": 2.0,
# }
# }
# "optimizer_settings": {
# "optimizer_kwargs": {
# "lr": 0.01,

View File

@@ -16,16 +16,17 @@ from datetime import datetime
import hashlib
from pathlib import Path
import time
import h5py
from matplotlib import pyplot as plt # noqa: F401
import numpy as np
import add_pypho # noqa: F401
from . import add_pypho # noqa: F401
import pypho
default_config = f"""
[glova]
nos = 256
sps = 256
sps = 128
nos = 16384
f0 = 193414489032258.06
symbolrate = 10e9
wisdom_dir = "{str((Path.home() / ".pypho"))}"
@@ -37,9 +38,9 @@ length = 10000
gamma = 1.14
alpha = 0.2
D = 17
S = 0
birefsteps = 0
max_delta_beta = 0.4
S = 0.058
bireflength = 10
max_delta_beta = 0.14
; birefseed = 0xC0FFEE
[signal]
@@ -47,17 +48,15 @@ max_delta_beta = 0.4
modulation = "pam"
mod_order = 4
mod_depth = 0.8
mod_depth = 1
max_jitter = 0.02
; jitter_seed = 0xC0FFEE
laser_power = 0
edfa_power = 3
edfa_power = 0
edfa_nf = 5
pulse_shape = "gauss"
fwhm = 0.33
osnr = "inf"
[data]
dir = "data"
@@ -71,6 +70,7 @@ def get_config(config_file=None):
"""
if config_file is None:
config_file = Path(__file__).parent / "signal_generation.ini"
config_file = Path(config_file)
if not config_file.exists():
with open(config_file, "w") as f:
f.write(default_config)
@@ -83,7 +83,10 @@ def get_config(config_file=None):
conf[section] = {}
for key in config[section]:
# print(f"{key} = {config[section][key]}")
try:
conf[section][key] = eval(config[section][key])
except NameError:
conf[section][key] = float(config[section][key])
# if isinstance(conf[section][key], str):
# conf[section][key] = config[section][key].strip('"')
return conf
@@ -96,7 +99,9 @@ class PDM_IM_IPM:
mod_order=8,
seed=None,
):
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1"
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
"mod_order must be a cube of an integer greater than 1"
)
self.glova = glova
self.mod_order = mod_order
self.symbols_per_dim = int(np.cbrt(mod_order))
@@ -110,14 +115,7 @@ class PDM_IM_IPM:
class pam_generator:
def __init__(
self,
glova,
mod_order=None,
mod_depth=0.5,
pulse_shape="gauss",
fwhm=0.33,
seed=None,
single_channel=False
self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
) -> None:
self.glova = glova
self.pulse_shape = pulse_shape
@@ -138,9 +136,7 @@ class pam_generator:
symbols_x = symbols[0] / (self.mod_order)
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
digital_x = np.pad(
digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
# create analog signal of diff of symbols
E_x = np.convolve(digital_x, wavelet)
@@ -158,16 +154,13 @@ class pam_generator:
symbols_y = symbols[1] / (self.mod_order)
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
digital_y = np.pad(
digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
E_y = np.convolve(digital_y, wavelet)
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
# rotate the signal on the y-polarisation by 90°
@@ -176,7 +169,6 @@ class pam_generator:
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
return E
def generate_digital_signal(self, symbols, max_jitter=0):
rs = np.random.RandomState(self.seed)
signal = np.zeros(self.glova.nos * self.glova.sps)
@@ -198,15 +190,11 @@ class pam_generator:
endpoint=True,
)
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
pulse = (
1
/ (sigma * np.sqrt(2 * np.pi))
* np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
)
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
return pulse
def initialize_fiber_and_data(config, input_data_override=None):
def initialize_fiber_and_data(config):
py_glova = pypho.setup(
nos=config["glova"]["nos"],
sps=config["glova"]["sps"],
@@ -221,22 +209,14 @@ def initialize_fiber_and_data(config, input_data_override=None):
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
if input_data_override is not None:
c_data.E_in = input_data_override[0]
noise = input_data_override[1]
else:
config["signal"]["seed"] = config["signal"].get(
"seed", (int(time.time() * 1000)) % 2**32
)
config["signal"]["jitter_seed"] = config["signal"].get(
"jitter_seed", (int(time.time() * 1000)) % 2**32
)
osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
)
laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
@@ -252,14 +232,28 @@ def initialize_fiber_and_data(config, input_data_override=None):
symbols_y[:3] = 0
# symbols_x += 1
cw = laser()
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
if osnr != float("inf"):
osnr_lin = 10 ** (osnr / 10)
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
noise_power = signal_power / osnr_lin
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
0, 1, source_signal[0]["E"].shape
)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
noise = noise * np.sqrt(noise_power / noise_power_is)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
source_signal[0]["E"] += noise
source_signal[0]["noise"] = noise_power_is
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
source_signal = py_edfa(E=source_signal)
nf = py_edfa.NF
source_signal = py_edfa(E=source_signal, NF=0)
py_edfa.NF = nf
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
@@ -273,25 +267,21 @@ def initialize_fiber_and_data(config, input_data_override=None):
S=config["fiber"]["s"],
)
if config["fiber"].get("birefsteps", 0) > 0:
seed = config["fiber"].get(
"birefseed", (int(time.time() * 1000)) % 2**32
)
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l,
py_fiber.l / config["fiber"]["birefsteps"],
# maxDeltaD=config["fiber"]["d"]/5,
maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
seed=seed,
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(
py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y)
def save_data(data, config):
def save_data(data, config, **metadata):
data_dir = Path(config["data"]["dir"])
npy_dir = config["data"].get("npy_dir", "")
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
@@ -306,6 +296,7 @@ def save_data(data, config):
seed = config["signal"].get("seed", False)
jitter_seed = config["signal"].get("jitter_seed", False)
birefseed = config["fiber"].get("birefseed", False)
osnr = float(config["signal"].get("osnr", "inf"))
config_content = "\n".join((
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
@@ -317,14 +308,14 @@ def save_data(data, config):
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
f'flags = "{config["glova"]["flags"]}"',
f"nthreads = {config['glova']['nthreads']}",
" ",
"",
"[fiber]",
f"length = {config['fiber']['length']}",
f"gamma = {config['fiber']['gamma']}",
f"alpha = {config['fiber']['alpha']}",
f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps',0)}",
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
"",
@@ -334,49 +325,62 @@ def save_data(data, config):
f'modulation = "{config["signal"]["modulation"]}"',
f"mod_order = {config['signal']['mod_order']}",
f"mod_depth = {config['signal']['mod_depth']}",
""
"",
f"max_jitter = {config['signal']['max_jitter']}",
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
""
"",
f"laser_power = {config['signal']['laser_power']}",
f"edfa_power = {config['signal']['edfa_power']}",
f"edfa_nf = {config['signal']['edfa_nf']}",
""
f"osnr = {osnr}",
"",
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
f"fwhm = {config['signal']['fwhm']}",
"",
"[data]",
f'dir = "{str(data_dir)}"',
f'npy_dir = "{npy_dir}"',
"file = "
"file = ",
))
config_hash = hashlib.md5(config_content.encode()).hexdigest()
save_file = f"{config_hash}.npy"
save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n'
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config['fiber'].get('birefsteps',0),
config["fiber"].get("birefsteps", 0),
config["fiber"].get("max_delta_beta", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
with open(data_dir / lookup_file, "w") as f:
config_filename = data_dir / lookup_file
with open(config_filename, "w") as f:
f.write(config_content)
np.save(save_dir / save_file, save_data)
with h5py.File(save_dir / save_file, "w") as outfile:
outfile.create_dataset("data", data=save_data)
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
for key, value in metadata.items():
# if isinstance(value, dict):
# value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data)
print("Saved config to", data_dir / lookup_file)
print("Saved config to", config_filename)
print("Saved data to", save_dir / save_file)
return config_filename
def length_loop(config, lengths, save=True):
lengths = sorted(lengths)
@@ -386,23 +390,19 @@ def length_loop(config, lengths, save=True):
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
cdata.E_out = E_tmp[0]["E"]
mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
if save:
save_data(cdata, config)
@@ -411,27 +411,57 @@ def length_loop(config, lengths, save=True):
def single_run_with_plot(config, save=True):
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
cfiber, cdata, config_filename = single_run(config, save)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename
def single_run(config, save=True):
cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)
# mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
# print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_in / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
# mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
# noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_out / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
E_tmp = [{"E": cdata.E_out, "noise": noise}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
if save:
save_data(cdata, config)
cdata.E_out = E_tmp[0]["E"]
# noise = E_tmp[0]["noise"]
# mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_amp / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
config_filename = None
symbols = np.array(symbols)
if save:
config_filename = save_data(cdata, config, **{"symbols": symbols})
return cfiber, cdata, config_filename
in_out_eyes(cfiber, cdata, show_pols=False)
def in_out_eyes(cfiber, cdata, show_pols=False):
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
@@ -595,9 +625,7 @@ def plot_eye_diagram(
signal = signal[: head * eye_width]
if normalize:
signal = signal / np.max(signal)
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
offset % (eye_width + 1) :: eye_width
]
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
for slice in slices:
ax.plot(plt_ax, slice, color=color, alpha=0.1)
@@ -618,14 +646,26 @@ if __name__ == "__main__":
# lengths = [*lengths, *lengths]
lengths = (
# 8000, 9000,
10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000,
95000, 100000, 105000, 110000, 115000, 120000
10000,
20000,
30000,
40000,
50000,
60000,
70000,
80000,
90000,
95000,
100000,
105000,
110000,
115000,
120000,
)
# lengths = (10000,100000)
length_loop(config, lengths, save=True)
# length_loop(config, lengths, save=True)
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
# single_run_with_plot(config, save=False)
single_run_with_plot(config, save=False)

View File

@@ -0,0 +1,723 @@
"""
tests a given model for tolerance against variations in
- fiber length
- baudrate
- OSNR
CD, PMD, and baudrate each require their own dataset; OSNR is modeled as AWGN added to the data before it is fed into the model
"""
from datetime import datetime
from typing import Literal
from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path
import h5py
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models
from signal_gen.generate_signal import single_run, get_config
import json
class NestedParameterIterator:
def __init__(self, parameters):
"""
parameters: dict with key <param_name> and value <dict with keys "config" and "range">
"""
# self.parameters = parameters
self.names = []
self.ranges = []
self.configs = []
for k, v in parameters.items():
self.names.append(k)
self.ranges.append(v["range"])
self.configs.append(v["config"])
self.n_parameters = len(self.ranges)
self.idx = 0
self.range_idx = [0] * self.n_parameters
self.range_len = [len(r) for r in self.ranges]
self.length = int(np.prod(self.range_len))
self.out = []
for i in range(self.length):
self.out.append([])
for j in range(self.n_parameters):
element = {self.names[j]: {"value": self.ranges[j][self.range_idx[j]], "config": self.configs[j]}}
self.out[i].append(element)
self.range_idx[-1] += 1
# update range_idx back to front
for j in range(self.n_parameters - 1, -1, -1):
if self.range_idx[j] == self.range_len[j]:
self.range_idx[j] = 0
self.range_idx[j - 1] += 1
...
def __next__(self):
if self.idx == self.length:
raise StopIteration
self.idx += 1
return self.out[self.idx - 1]
def __iter__(self):
return self
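# Usage sketch (hypothetical values): the iterator yields every combination
# of the registered ranges, the last-added parameter varying fastest; each
# item is a list with one {name: {"value": ..., "config": ...}} entry per
# parameter.
#
#   it = NestedParameterIterator({
#       "OSNR": {"config": ("signal", "osnr"), "range": [20, 30, 40]},
#       "D": {"config": ("fiber", "d"), "range": [16, 17, 18]},
#   })
#   for combo in it:  # 9 combinations
#       ...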
class model_runner:
def __init__(
self,
# length_range: tuple[int | float] = (50e3, 50e3),
# length_steps: int = 1,
# length_log: bool = False,
# baudrate_range: tuple[int | float] = (10e9, 10e9),
# baudrate_steps: int = 1,
# baudrate_log: bool = False,
# osnr_range: tuple[int | float] = (16, 16),
# osnr_steps: int = 1,
# osnr_log: bool = False,
# dataset_dir: str = "data",
# dataset_datetime_glob: str = "*",
results_dir: str = "tolerance_results/datasets",
# model_dir: str = ".models",
config: str = "signal_generation.ini",
config_dir: str = None,
debug: bool = False,
):
"""
results_dir: directory to save results
config: name of the base signal-generation config file
config_dir: directory containing the config file (defaults to this script's directory)
debug: if True, load only a small number of symbols per dataset
"""
self.debug = debug
self.parameters = {}
self.iter = None
# self.update_length_range(length_range, length_steps, length_log)
# self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log)
# self.update_osnr_range(osnr_range, osnr_steps, osnr_log)
# self.data_dir = Path(dataset_dir)
# self.data_datetime_glob = dataset_datetime_glob
self.results_dir = Path(results_dir)
# self.model_dir = Path(model_dir)
config_dir = config_dir or Path(__file__).parent
self.config = config_dir / config
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
self.load_model()
self.datasets = []
# def register_parameter(self, name, config):
# self.parameters.append({"name": name, "config": config})
def load_results_from_file(self, path):
data, meta = self.load_from_file(path)
self.results = [d.decode() for d in data]
self.parameters = meta["parameters"]
...
def load_datasets_from_file(self, path):
data, meta = self.load_from_file(path)
self.datasets = [d.decode() for d in data]
self.parameters = meta["parameters"]
...
def update_parameter_range(self, name, config, range, steps, log):
self.parameters[name] = {"config": config, "range": self.update_range(*range, steps, log)}
def generate_iterations(self):
if len(self.parameters) == 0:
raise ValueError("No parameters registered")
self.iter = NestedParameterIterator(self.parameters)
def generate_datasets(self):
# get base config
config = get_config(self.config)
if self.iter is None:
self.generate_iterations()
for params in self.iter:
current_settings = []
# params is a list of dictionaries, each mapping a parameter name to a dict with keys "value" and "config"
for param in params:
for name, settings in param.items():
current_settings.append({name: settings["value"]})
self.nested_set(config, settings["config"], settings["value"])
settings_strs = []
for setting in current_settings:
name = list(setting)[0]
settings_strs.append(f"{name}: {float(setting[name]):.2e}")
settings_str = ", ".join(settings_strs)
print(f"Generating dataset for [{settings_str}]")
# TODO look for existing datasets
_, _, path = single_run(config)
self.datasets.append(str(path))
datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back")
metadata = {"parameters": self.parameters}
data = np.array(self.datasets, dtype="S")
self.save_to_file(datasets_list_path, data, **metadata)
@staticmethod
def nested_set(dic, keys, value):
for key in keys[:-1]:
dic = dic.setdefault(key, {})
dic[keys[-1]] = value
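# e.g. nested_set(config, ("fiber", "d"), 17.0) is equivalent to
# config["fiber"]["d"] = 17.0, creating intermediate dicts as needed.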
## Dataset and model loading
# def find_datasets(self, data_dir=None, data_datetime_glob=None):
# # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini
# data_dir = data_dir or self.data_dir
# data_datetime_glob = data_datetime_glob or self.data_datetime_glob
# self.datasets = {}
# data_dir = Path(data_dir)
# for length in self.lengths:
# for baudrate in self.baudrates:
# # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini"
# dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini"
# datasets = [f for f in data_dir.glob(dataset_glob)]
# if len(datasets) == 0:
# continue
# self.datasets[length] = {}
# if len(datasets) > 1:
# print(
# f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset."
# )
# # get newest file from creation date
# datasets.sort(key=lambda x: x.stat().st_ctime)
# self.datasets[length][baudrate] = str(datasets[-1])
def load_dataset(self, dataset_path):
if self.checkpoint_dict is None:
raise ValueError("Model must be loaded before dataset")
if self.dataset_path == dataset_path:
return
self.dataset_path = dataset_path
symbols = self.checkpoint_dict["settings"]["data_settings"].symbols
data_size = self.checkpoint_dict["settings"]["data_settings"].output_size
dtype = getattr(torch, self.checkpoint_dict["settings"]["data_settings"].dtype)
drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first
randomise_polarisations = self.checkpoint_dict["settings"]["data_settings"].randomise_polarisations
polarisations = self.checkpoint_dict["settings"]["data_settings"].polarisations
num_symbols = None
if self.debug:
num_symbols = 1000
config_path = Path(dataset_path)
dataset = util.datasets.FiberRegenerationDataset(
file_path=config_path,
symbols=symbols,
output_dim=data_size,
drop_first=drop_first,
dtype=dtype,
real=not dtype.is_complex,
randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
num_symbols=num_symbols,
# device="cuda" if torch.cuda.is_available() else "cpu",
)
self.dataloader = torch.utils.data.DataLoader(
dataset, batch_size=2**14, pin_memory=True, num_workers=24, prefetch_factor=8, shuffle=False
)
return self.dataloader.dataset.orig_symbols
# run model
# return results as array: [fiber_in, fiber_out, regen_out, timestamp]
def load_model(self, model_path: str | None = None):
if model_path is None:
self.model = None
self.model_path = None
self.checkpoint_dict = None
return
path = Path(model_path)
if path == self.model_path:
return
self.model_path = path
self.dataset_path = None # reset dataset path, as the shape depends on the model
self.checkpoint_dict = torch.load(path, weights_only=True)
dims = self.checkpoint_dict["model_kwargs"].pop("dims")
self.model = models.regenerator(*dims, **self.checkpoint_dict["model_kwargs"])
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"])
## Model evaluation
def run_model_evaluation(self, model_path: str, datasets: str | None = None):
self.load_model(model_path)
# iterate over datasets and osnr values:
# load dataset, add noise, run model, return results
# save results to file
self.results = []
if datasets is not None:
self.load_datasets_from_file(datasets)
n_datasets = len(self.datasets)
for i, dataset_path in enumerate(self.datasets):
conf = get_config(dataset_path)
mpath = Path(model_path)
model_base = mpath.stem
print(f"({1+i}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}")
results_path = self.build_path(
dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base
)
orig_symbols = self.load_dataset(dataset_path)
data, loss = self.run_model()
metadata = {
"model_path": model_path,
"dataset_path": dataset_path,
"loss": loss,
"sps": conf["glova"]["sps"],
"orig_symbols": orig_symbols
# "config": conf,
# "checkpoint_dict": self.checkpoint_dict,
# "nos": self.dataloader.dataset.num_symbols,
}
self.save_to_file(results_path, data, **metadata)
self.results.append(str(results_path))
results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back")
metadata = {"parameters": self.parameters}
data = np.array(self.results, dtype="S")
self.save_to_file(results_list_path, data, **metadata)
def run_model(self):
loss = 0
datas = []
self.model.eval()
model = self.model.to("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
for batch in self.dataloader:
x = batch["x_stacked"]
y = batch["y_stacked"]
fiber_in = batch["plot_target"]
# fiber_out = batch["plot_clean"]
fiber_out = batch["plot_data"]
timestamp = batch["timestamp"]
angle = batch["mean_angle"]
x = x.to("cuda" if torch.cuda.is_available() else "cpu")
angle = angle.to("cuda" if torch.cuda.is_available() else "cpu")
regen = model(x, -angle)
regen = regen.to("cpu")
loss += util.complexNN.complex_mse_loss(regen, y, power=True).item()
# shape: [batch_size, 4]
plot_regen = regen[:, :2]
plot_regen = plot_regen.view(plot_regen.shape[0], -1, 2)
plot_regen = plot_regen[:, plot_regen.shape[1] // 2, :]
data_out = torch.cat(
(
fiber_in,
fiber_out,
# fiber_out_noisy,
plot_regen,
timestamp.view(-1, 1),
),
dim=1,
)
datas.append(data_out)
data_out = torch.cat(datas, dim=0).numpy()
return data_out, loss
## File I/O
@staticmethod
def save_to_file(path: str, data: np.ndarray, **metadata: dict):
# create directory if it doesn't exist
path.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(path, "w") as outfile:
outfile.create_dataset("data", data=data)
for key, value in metadata.items():
if isinstance(value, dict):
value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
@staticmethod
def convert_arrays(dict_in):
"""
convert ndarrays in (nested) dict to lists
"""
dict_out = {}
for key, value in dict_in.items():
if isinstance(value, dict):
dict_out[key] = model_runner.convert_arrays(value)
elif isinstance(value, np.ndarray):
dict_out[key] = value.tolist()
else:
dict_out[key] = value
return dict_out
@staticmethod
def load_from_file(path: str):
with h5py.File(path, "r") as infile:
data = infile["data"][:]
metadata = {}
for key in infile.attrs.keys():
if isinstance(infile.attrs[key], (str, bytes, bytearray)):
try:
metadata[key] = json.loads(infile.attrs[key])
except json.JSONDecodeError:
metadata[key] = infile.attrs[key]
else:
metadata[key] = infile.attrs[key]
return data, metadata
## Utility functions
@staticmethod
def logrange(start, stop, num, endpoint=False):
lower, upper = np.log10((start, stop))
return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10)
@staticmethod
def build_path(
*elements, parent_dir: str | Path | None = None, filetype="h5", timestamp: Literal["no", "front", "back"] = "no"
):
suffix = f".{filetype}" if not filetype.startswith(".") else filetype
if timestamp != "no":
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
if timestamp == "front":
elements = (ts, *elements)
elif timestamp == "back":
elements = (*elements, ts)
path = "_".join(elements)
path += suffix
if parent_dir is not None:
path = Path(parent_dir) / path
return path
@staticmethod
def update_range(min, max, n_steps, log):
if log:
range = model_runner.logrange(min, max, n_steps, endpoint=True)
else:
range = np.linspace(min, max, n_steps, endpoint=True)
return range
class model_evaluation_result:
def __init__(
self,
*,
length=None,
baudrate=None,
osnr=None,
model_path=None,
dataset_path=None,
loss=None,
sps=None,
**kwargs,
):
self.length = length
self.baudrate = baudrate
self.osnr = osnr
self.model_path = model_path
self.dataset_path = dataset_path
self.loss = loss
self.sps = sps
self.sers = None
self.bers = None
self.eye_stats = None
class evaluator:
def __init__(self, datasets: list[str]):
"""
datasets: iterable of dataset paths
"""
self.datasets = datasets
self.results = []
def evaluate_datasets(self, plot=False):
## iterate over datasets
# load dataset
for dataset in self.datasets:
model, dataset_name = dataset.split("/")[-2:]
print(f"\nEvaluating model {model} with dataset {dataset_name}")
data, metadata = model_runner.load_from_file(dataset)
result = model_evaluation_result(**metadata)
data = self.prepare_data(data, sps=metadata["sps"])
try:
sym_x, sym_y = metadata["orig_symbols"]
except (TypeError, KeyError, ValueError):
sym_x, sym_y = None, None
self.evaluate_eye(data, result, title=dataset.split("/")[-1], plot=False)
self.evaluate_ser_ber(data, result, sym_x, sym_y)
print("BER:")
self.print_dict(result.bers["regen"])
print()
print("SER:")
self.print_dict(result.sers["regen"])
print()
self.results.append(result)
if plot:
plt.show()
def evaluate_eye(self, data, result, title=None, plot=False):
eye = util.eye_diagram.eye_diagram(
data,
channel_names=[
"fiber_in_x",
"fiber_in_y",
# "fiber_out_x",
# "fiber_out_y",
"fiber_out_x",
"fiber_out_y",
"regen_x",
"regen_y",
],
)
eye.analyse()
eye.plot(title=title or "Eye diagram", show=plot)
result.eye_stats = eye.eye_stats
return eye.eye_stats
...
def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None):
if result.eye_stats is None:
self.evaluate_eye(data, result)
symbols = []
sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
for channel_data, stats in zip(data, result.eye_stats):
timestamps = channel_data[0]
dat = channel_data[1]
channel_name = stats["channel_name"]
if stats["success"]:
thresholds = stats["thresholds"]
time_midpoint = stats["time_midpoint"]
else:
if channel_name.endswith("x"):
thresholds = result.eye_stats[0]["thresholds"]
time_midpoint = result.eye_stats[0]["time_midpoint"]
elif channel_name.endswith("y"):
thresholds = result.eye_stats[1]["thresholds"]
time_midpoint = result.eye_stats[1]["time_midpoint"]
else:
levels = np.linspace(np.min(dat), np.max(dat), 4)
thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels)
time_midpoint = 1.0
# time_offset = time_midpoint - 0.5
# # time_offset = 0
# index_offset = np.argmin(np.abs((timestamps - time_offset) % 1.0))
nos = len(timestamps) // result.sps
# idx = np.arange(index_offset, len(timestamps), result.sps).astype(int)
# if time_offset < 0:
# idx = np.insert(idx, 0, 0)
idx = list(range(0, len(timestamps), result.sps))
idx = idx[:nos]
data_sampled = dat[idx]
detected_symbols = self.detect_symbols(data_sampled, thresholds)
symbols.append({"channel_name": channel_name, "symbols": detected_symbols})
# `sym_x or ...` would raise on a multi-element ndarray, so test for None explicitly
symbols_x_gt = symbols[0]["symbols"] if sym_x is None else sym_x
symbols_y_gt = symbols[1]["symbols"] if sym_y is None else sym_y
symbols_x_fiber_out = symbols[2]["symbols"]
symbols_y_fiber_out = symbols[3]["symbols"]
symbols_x_regen = symbols[4]["symbols"]
symbols_y_regen = symbols[5]["symbols"]
sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out)
sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out)
sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen)
sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen)
result.sers = sers
result.bers = bers
@staticmethod
def calculate_ser_ber(symbols_gt, symbols):
# levels = 4
# symbol difference -> bit error count
# |rx - tx| = 0 -> 0
# |rx - tx| = 1 -> 1
# |rx - tx| = 2 -> 2
# |rx - tx| = 3 -> 1
# assuming gray coding -> 0: 00, 1: 01, 2: 11, 3: 10
bec_map = {0: 0, 1: 1, 2: 2, 3: 1, np.nan: 2}
ser = {}
ber = {}
ser["n_symbols"] = len(symbols_gt)
ser["n_errors"] = np.sum(symbols != symbols_gt)
ser["total"] = float(ser["n_errors"] / ser["n_symbols"])
bec = np.vectorize(bec_map.get)(np.abs(symbols - symbols_gt))
bit_errors = np.sum(bec)
ber["n_bits"] = len(symbols_gt) * 2
ber["n_errors"] = bit_errors
ber["total"] = float(ber["n_errors"] / ber["n_bits"])
return ser, ber
@staticmethod
def print_dict(d: dict, indent=2, logarithmic=False, level=0):
for key, value in d.items():
if isinstance(value, dict):
print(f"{' ' * indent * level}{key}:")
evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1)
else:
if isinstance(value, float):
if logarithmic:
if value == 0:
value = -np.inf
else:
value = np.log10(value)
print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="")
else:
print(f"{' ' * indent * level}{key}: {value}\t", end="")
print()
@staticmethod
def detect_symbols(samples, thresholds=None):
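# Default thresholds assume four normalised PAM levels at 0, 1/3, 2/3 and 1;
# the decision boundaries 1/6, 1/2, 5/6 sit midway between adjacent levels.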
thresholds = (1 / 6, 3 / 6, 5 / 6) if thresholds is None else thresholds
thresholds = (-np.inf, *thresholds, np.inf)
symbols = np.digitize(samples, thresholds) - 1
return symbols
@staticmethod
def prepare_data(data, sps=None):
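# Returns shape [n_channels, 2, n_samples]: per channel a (timestamp, power)
# row pair, with power = |E|^2 and timestamps in symbol units when sps is given.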
data = data.transpose(1, 0)
timestamps = data[-1].real
data = data[:-1]
if sps is not None:
timestamps /= sps
# data = np.stack(
# (
# *data[0:2], # fiber_in_x, fiber_in_y
# # *data_[2:4], # fiber_out_x, fiber_out_y
# *data[4:6], # fiber_out_noisy_x, fiber_out_noisy_y
# *data[6:8], # regen_out_x, regen_out_y
# ),
# axis=0,
# )
data_eye = []
for channel_values in data:
channel_values = np.square(np.abs(channel_values))
data_eye.append(np.stack((timestamps, channel_values), axis=0))
data_eye = np.stack(data_eye, axis=0)
return data_eye
def generate_data(parameters, runner=None):
runner = runner or model_runner()
for param in parameters:
runner.update_parameter_range(*param)
runner.generate_iterations()
print(f"{runner.iter.length} parameter combinations")
runner.generate_datasets()
return runner
if __name__ == "__main__":
model_path = ".models/best_20250110_191149.tar" # D 17, OSNR 100, delta_beta 0.14, baud 10e9
parameters = (
# name, config keys, (min, max), n_steps, log
# ("D", ("fiber", "d"), (28,30), 3, False),
# ("S", ("fiber", "s"), (0, 0.058), 2, False),
("OSNR", ("signal", "osnr"), (20, 40), 5, False),
# ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False),
# ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True),
)
datasets = None
results = None
# datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5"
results = "tolerance_results/datasets/results_list_20250110_232639.h5"
runner = model_runner()
# generate_data(parameters, runner)
if results is None:
if datasets is None:
generate_data(parameters, runner)
else:
runner.load_datasets_from_file(datasets)
print(f"{len(runner.datasets)} loaded")
runner.run_model_evaluation(model_path)
else:
runner.load_results_from_file(results)
# print(runner.parameters)
# print(runner.results)
eval = evaluator(runner.results)
eval.evaluate_datasets(plot=True)

View File

@@ -482,6 +482,15 @@ class Identity(nn.Module):
def forward(self, x):
return x
class phase_shift(nn.Module):
def __init__(self, size):
super(phase_shift, self).__init__()
self.size = size
self.phase = nn.Parameter(torch.rand(size))
def forward(self, x):
return x * torch.exp(1j*self.phase)
class PowRot(nn.Module):
def __init__(self, bias=False):
@@ -531,19 +540,19 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
class EOActivation(nn.Module):
def __init__(self, size=None):
# 10.1109/SiPhotonics60897.2024.10543376
# 10.1109/JSTQE.2019.2930455
super(EOActivation, self).__init__()
if size is None:
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.ones(size))
self.V_bias = nn.Parameter(torch.ones(size))
self.gain = nn.Parameter(torch.ones(size))
self.alpha = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.rand(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
@@ -551,17 +560,17 @@ class EOActivation(nn.Module):
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5
self.alpha.data = torch.rand(self.size)
if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3
self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size)
self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters:
self.gain.data = torch.ones(self.size)
self.gain.data = torch.rand(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size)
# if "bias" in self._parameters:
# self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
@@ -570,12 +579,11 @@ class EOActivation(nn.Module):
return (
1j
* torch.sqrt(1 - self.alpha)
* torch.exp(-0.5j * (intermediate + self.phase_bias))
* torch.exp(-0.5j * intermediate)
* torch.cos(0.5 * intermediate)
* x
)
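The `intermediate` term is outside this hunk; based on the cited papers, it is plausibly the bias phase plus a self-phase shift proportional to the tapped optical power. A standalone sketch under that assumption (the exact form of `intermediate` and all parameter values below are assumptions, not repository code):

```py
import torch

def eo_activation(x, alpha, v_bias, v_pi, gain, responsivity):
    # sketch of the electro-optic activation of 10.1109/JSTQE.2019.2930455;
    # the form of `intermediate` here is an assumption
    phi_b = torch.pi * v_bias / (v_pi + 1e-8)
    intermediate = phi_b + gain * responsivity * alpha * x.abs().square() * torch.pi / v_pi
    return (
        1j
        * torch.sqrt(1 - alpha)
        * torch.exp(-0.5j * intermediate)
        * torch.cos(0.5 * intermediate)
        * x
    )

x = torch.ones(4, dtype=torch.complex128)
y = eo_activation(x, alpha=torch.tensor(0.1), v_bias=torch.tensor(1.0),
                  v_pi=torch.tensor(3.0), gain=torch.tensor(1.0),
                  responsivity=torch.tensor(0.9))
print(y.abs())  # attenuated by sqrt(1 - alpha) * |cos(intermediate / 2)| relative to |x|
```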
class Pow(nn.Module):
"""
implements the activation function
@@ -716,6 +724,7 @@ __all__ = [
MZISingle,
EOActivation,
photodiode,
phase_shift,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,

View File

@@ -1,4 +1,5 @@
from pathlib import Path
import h5py
import torch
from torch.utils.data import Dataset
@@ -24,8 +25,22 @@ import multiprocessing as mp
# def __len__(self):
# return len(self.indices)
def load_from_file(datapath):
if str(datapath).endswith('.h5'):
symbols = None
with h5py.File(datapath, "r") as infile:
data = infile["data"][:]
try:
symbols = infile["symbols"][:]
except KeyError:
pass
else:
symbols = None
data = np.load(datapath)
return data, symbols
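A small round-trip sketch of the `.h5` layout that `load_from_file` expects, datasets named `data` and (optionally) `symbols`; the file name is made up and the function above is assumed to be in scope:

```py
import h5py
import numpy as np

with h5py.File("example.h5", "w") as f:
    f["data"] = np.random.randn(1024, 4) + 1j * np.random.randn(1024, 4)
    f["symbols"] = np.random.randint(0, 4, size=1024)

data, symbols = load_from_file("example.h5")
print(data.shape, symbols.shape)  # (1024, 4) (1024,)
```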
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser()
@@ -41,14 +56,20 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
data, orig_symbols = load_from_file(datapath)
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
if orig_symbols is not None:  # .npy inputs carry no symbols, so guard the slice
    orig_symbols = orig_symbols[skipfirst:symbols+skipfirst]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
data = np.sqrt(np.array([a, b, c, d]).T)
data *= np.sqrt(normalize)
# if normalize:
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
# a, b, c, d = data.T
# a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
# a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
# data = np.array([a, b, c, d]).T
if real:
data = np.abs(data)
@@ -59,7 +80,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
return data, config, orig_symbols
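Note that `normalize` changed from a flag to a linear power target: scaling the field amplitude by `sqrt(normalize)` scales the power by `normalize`, so `normalize=1000` raises the power by 30 dB. A one-line check:

```py
import numpy as np

field = np.ones(4, dtype=np.complex128)
scaled = field * np.sqrt(1000)
print(np.abs(scaled) ** 2 / np.abs(field) ** 2)  # [1000. 1000. 1000. 1000.]
```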
def roll_along(arr, shifts, dim):
@@ -114,7 +135,8 @@ class FiberRegenerationDataset(Dataset):
dtype: torch.dtype = None,
real: bool = False,
device=None,
osnr: float = None,
# osnr: float|None = None,
polarisations = None,
randomise_polarisations: bool = False,
repeat_randoms: int = 1,
**kwargs,
@@ -151,36 +173,26 @@ class FiberRegenerationDataset(Dataset):
self.randomise_polarisations = randomise_polarisations
faux = kwargs.pop("faux", False)
if faux:
data_raw = np.array(
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
dtype=np.complex128,
)
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
timestamps = torch.arange(12800)
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw = None
self.config = None
files = []
self.orig_symbols = None
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
data, config = load_data(
data, config, orig_syms = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=True,
normalize=1000,
device=device,
dtype=dtype,
)
if orig_syms is not None:
if self.orig_symbols is None:
self.orig_symbols = orig_syms
else:
self.orig_symbols = np.concatenate((self.orig_symbols, orig_syms), axis=-1)
if data_raw is None:
data_raw = data
else:
@@ -193,23 +205,18 @@ class FiberRegenerationDataset(Dataset):
self.config["data"]["file"] = str(files)
# if polarisations is not None:
# self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1)
# for i, angle in enumerate(torch.tensor(np.array(polarisations))):
# data_raw_copy = data_raw.clone()
# if angle == 0:
# continue
# sine = torch.sin(angle)
# cosine = torch.cos(angle)
# data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
# data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
# if i == 0:
# data_raw = data_raw_copy
# else:
# data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
# data_raw_clone = data_raw.clone()
# # rotate the polarisation by 180 degrees
# data_raw_clone[2, :] *= -1
# data_raw_clone[3, :] *= -1
# data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
self.polarisations = bool(polarisations)
self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"])
# self.num_symbols = int(self.config["glova"]["nos"])
self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -290,6 +297,34 @@ class FiberRegenerationDataset(Dataset):
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps]
# add the amplifier (ASE) noise incurred because the signal is split and must be re-amplified
gain_lin = output_dim*2
edfa_nf = float(self.config["signal"]["edfa_nf"])
nf_lin = 10**(edfa_nf/10)
f0 = float(self.config["glova"]["f0"])
noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
noise = torch.randn_like(fiber_out[:2, :])
noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
noise = noise * torch.sqrt(noise_add / noise_power)
fiber_out[:2, :] += noise
# if osnr is None:
# noisy = fiber_out[:2, :]
# else:
# noisy = self.add_noise(fiber_out[:2, :], osnr)
# fiber_out = torch.cat([fiber_out, noisy], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
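The added term follows the standard ASE-power estimate P_ase = (G - 1) * NF * h * f0 * B_ref, with Planck's constant h and a 12.5 GHz reference bandwidth; the gain is taken to equal the splitting factor `output_dim * 2`. A worked number (the NF and f0 values below are assumed for illustration):

```py
gain_lin = 16        # e.g. output_dim = 8
edfa_nf = 4.0        # noise figure in dB (assumed)
nf_lin = 10 ** (edfa_nf / 10)
f0 = 193.1e12        # carrier frequency in Hz (assumed)
h = 6.626e-34        # Planck's constant in J*s
noise_add = (gain_lin - 1) * nf_lin * h * f0 * 12.5e9
print(f"{noise_add:.2e} W")  # ~6.0e-08 W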
if repeat_randoms > 1:
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
@@ -298,27 +333,33 @@ class FiberRegenerationDataset(Dataset):
repeat_randoms = 1
if self.randomise_polarisations:
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
# start_angle = torch.rand(1) * 2 * torch.pi
# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
else:
angles = torch.zeros(data_raw.shape[-1])
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
sin = torch.sin(angles)
cos = torch.cos(angles)
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
# data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
if osnr is not None:
popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1)
noise = torch.randn_like(fiber_out[:2, :, :])
pn = torch.mean(noise.abs().flatten(), dim=-1)
noise = noise * (popt / pn) * 10 ** (-osnr / 20)
fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise)
# fiber_in:
# 0 E_in_x,
# 1 E_in_y,
# 2 timestamps
# fiber_out:
# 0 E_out_x,
# 1 E_out_y,
# 2 timestamps,
# 3 E_out_x_rot,
# 4 E_out_y_rot,
# 5 angle
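The rotation applied above is a standard 2x2 Jones rotation batched over samples; a minimal self-contained sketch of the same `stack`/`bmm` construction:

```py
import torch

n = 5
angles = torch.rand(n) * 2 * torch.pi
sin, cos = torch.sin(angles), torch.cos(angles)
# one 2x2 rotation matrix per sample, as in the construction above
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
field = torch.randn(2, n, dtype=torch.complex128)
rotated = torch.bmm(field.T.unsqueeze(1), rot.to(field.dtype)).squeeze(1).T
assert rotated.shape == (2, n)
```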
@@ -350,6 +391,22 @@ class FiberRegenerationDataset(Dataset):
def __len__(self):
return self.fiber_in.shape[0]
def add_noise(self, data, osnr):
osnr_lin = 10**(osnr/10)
popt = torch.mean(data.abs().square().squeeze(), dim=-1)
noise = torch.randn_like(data)
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
mult = torch.sqrt(popt/(pn*osnr_lin))
mult = mult * torch.eye(popt.shape[0], device=mult.device)
mult = mult.to(dtype=noise.dtype)
noise = mult @ noise
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
noisy = data + noise
return noisy
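A quick check of `add_noise` (a sketch; the method never uses `self`, so it is called unbound here): the per-channel noise power should land `osnr` dB below the signal power.

```py
import torch

data = torch.randn(2, 100_000, dtype=torch.complex128)
noisy = FiberRegenerationDataset.add_noise(None, data, osnr=20)
p_sig = data.abs().square().mean(dim=-1)
p_noise = (noisy - data).abs().square().mean(dim=-1)
print(10 * torch.log10(p_sig / p_noise))  # ~[20., 20.]
```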
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
@@ -357,14 +414,19 @@ class FiberRegenerationDataset(Dataset):
# fiber in: [E_in_x, E_in_y, timestamps]
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
# if self.polarisations:
output_dim = self.output_dim // 2
self.output_dim = output_dim * 2
fiber_in = self.fiber_in[idx].squeeze()
fiber_out = self.fiber_out[idx].squeeze()
fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim]
fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim]
fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1)
# data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
@@ -372,11 +434,36 @@ class FiberRegenerationDataset(Dataset):
# angle = self.angles[idx]
center_angle = fiber_out[5, self.output_dim // 2, 0]
# fiber_in:
# 0 E_in_x,
# 1 E_in_y,
# 2 timestamps
# fiber_out:
# 0 E_out_x,
# 1 E_out_y,
# 2 timestamps,
# 3 E_out_x_rot,
# 4 E_out_y_rot,
# 5 angle
center_angle = fiber_out[5, output_dim // 2, 0]
angles = fiber_out[5, :, 0]
plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone()
plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone()
data = fiber_out[3:5, :, 0]
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
data = fiber_out[0:2, :, 0]
# fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone()
# if self.polarisations:
# rot = int(np.random.randint(2)*2-1)
# pol_flipped_data[0:1, :] = rot*data[0, :]
# pol_flipped_data[1, :] = rot*data[1, :]
# plot_data_rot[0] = rot*plot_data_rot[0]
# plot_data_rot[1] = rot*plot_data_rot[1]
# center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
# angles = angles + (rot - 1) * torch.pi/2
# if self.randomise_polarisations:
# data = data.mT
@@ -389,16 +476,27 @@ class FiberRegenerationDataset(Dataset):
# angle = torch.zeros_like(angle)
# for both polarisations (x/y), average the signal around the current symbol (equivalent to a low-pass filter)
angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim)
angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim)
# angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim)
# angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim)
# sop = self.polarimeter(plot_data)
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
target = fiber_in[:2, self.output_dim // 2, 0]
plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone()
target_timestamp = fiber_in[2, self.output_dim // 2, 0].real
target = fiber_in[:2, output_dim // 2, 0]
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
target_timestamp = fiber_in[2, output_dim // 2, 0].real
...
if self.polarisations:
rot = int(np.random.randint(2)*2-1)
data = rot*data
target = rot*target
plot_data_rot = rot*plot_data_rot
center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
angles = angles + (rot - 1) * torch.pi/2
pol_flipped_data = -data
pol_flipped_target = -target
# data_timestamps = data[-1,:].real
# data = data[:-1, :]
# target_timestamp = target[-1].real
@@ -407,13 +505,15 @@ class FiberRegenerationDataset(Dataset):
# transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze()
angle_data = angle_data.transpose(0, 1).flatten().squeeze()
angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
center_angle = center_angle.flatten().squeeze()
angles = angles.flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze()
# target = target.transpose(0,1).flatten().squeeze()
target = target.flatten().squeeze()
pol_flipped_target = pol_flipped_target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze()
plot_target = plot_target.flatten().squeeze()
plot_data = plot_data.flatten().squeeze()
@@ -421,17 +521,22 @@ class FiberRegenerationDataset(Dataset):
return {
"x": data,
"x_flipped": pol_flipped_data,
"x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
"y": target,
"center_angle": center_angle,
"angles": angles,
"y_flipped": pol_flipped_target,
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
# "center_angle": center_angle,
# "angles": angles,
"mean_angle": angles.mean(),
# "sop": sop,
"angle_data": angle_data,
"angle_data2": angle_data2,
# "angle_data": angle_data,
# "angle_data2": angle_data2,
"timestamp": target_timestamp,
"plot_target": plot_target,
"plot_data": plot_data,
"plot_data_rot": plot_data_rot,
# "plot_clean": fiber_out_plot_clean,
}
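Downstream code consumes each sample as a dict; a usage sketch (the `dataset` construction is omitted and assumed). Note that `x_stacked` concatenates the original and the 180°-flipped copy, so it is twice the length of `x`:

```py
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)  # dataset: a FiberRegenerationDataset
batch = next(iter(loader))
x, y = batch["x_stacked"], batch["y_stacked"]
assert x.shape[-1] == 2 * batch["x"].shape[-1]
```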
def complex_max(self, data, dim=-1):

View File

@@ -82,7 +82,6 @@ class eye_diagram:
self.vertical_bins = vertical_bins
self.multi_threaded = multithreaded
self.eye_built = False
self.analyse()
def generate_eye_data(self):
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
@@ -126,6 +125,7 @@ class eye_diagram:
rows = int(np.ceil(self.channels / cols))
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
fig.suptitle(title)
fig.tight_layout()
ax = np.atleast_1d(ax).transpose().flatten()
for i in range(self.channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
@@ -147,19 +147,21 @@ class eye_diagram:
yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
if stats and self.eye_stats[i]["success"]:
# add min_area above the plot
ax[i].annotate(
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
xy=(0.05, ymax + 0.05 * yspan),
# xycoords="axes fraction",
ha="left",
va="center",
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
)
# # add min_area above the plot
# ax[i].annotate(
# f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
# xy=(0.05, ymax + 0.05 * yspan),
# # xycoords="axes fraction",
# ha="left",
# va="center",
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
# )
if all_stats:
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
ax[i].set_yticks(self.eye_stats[i]["levels"])
y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
# y_ticks = np.sort(y_ticks)
ax[i].set_yticks(y_ticks)
# add arrows for amplitudes
for j in range(len(self.eye_stats[i]["amplitudes"])):
ax[i].annotate(
@@ -193,35 +195,35 @@ class eye_diagram:
except (ValueError, IndexError):
pass
# add arrows for eye widths
for j in range(len(self.eye_stats[i]["widths"])):
try:
left = np.max(self.eye_stats[i]["time_clusters"][j][0])
right = np.min(self.eye_stats[i]["time_clusters"][j][1])
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
# for j in range(len(self.eye_stats[i]["widths"])):
# try:
# left = np.max(self.eye_stats[i]["time_clusters"][j][0])
# right = np.min(self.eye_stats[i]["time_clusters"][j][1])
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate(
"",
xy=(left, vertical),
xytext=(right, vertical),
arrowprops=dict(arrowstyle="<->", facecolor="black"),
)
ax[i].annotate(
f"{self.eye_stats[i]['widths'][j]:.2e}",
xy=((left + right) / 2 - 0.15, vertical + 0.01),
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
)
except (ValueError, IndexError):
pass
# ax[i].annotate(
# "",
# xy=(left, vertical),
# xytext=(right, vertical),
# arrowprops=dict(arrowstyle="<->", facecolor="black"),
# )
# ax[i].annotate(
# f"{self.eye_stats[i]['widths'][j]:.2e}",
# xy=((left + right) / 2 - 0.15, vertical + 0.01),
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
# )
# except (ValueError, IndexError):
# pass
# add area
for j in range(len(self.eye_stats[i]["areas"])):
horizontal = self.eye_stats[i]["time_midpoint"]
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate(
f"{self.eye_stats[i]['areas'][j]:.2e}",
xy=(horizontal + 0.035, vertical - 0.07),
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
)
# # add area
# for j in range(len(self.eye_stats[i]["areas"])):
# horizontal = self.eye_stats[i]["time_midpoint"]
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
# ax[i].annotate(
# f"{self.eye_stats[i]['areas'][j]:.2e}",
# xy=(horizontal + 0.035, vertical - 0.07),
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
# )
fig.tight_layout()
@@ -229,6 +231,12 @@ class eye_diagram:
plt.show()
return fig
@staticmethod
def calculate_thresholds(levels):
ret = np.cumsum(levels, dtype=float)
ret[2:] = ret[2:] - ret[:-2]
return ret[1:]/2
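`calculate_thresholds` is a two-point moving average written with the cumulative-sum trick: for levels l0..l3 it returns the midpoints (l0+l1)/2, (l1+l2)/2, (l2+l3)/2, i.e. the decision thresholds between adjacent eye levels. Worked through for idealised PAM4 levels:

```py
import numpy as np

levels = np.array([0.0, 1.0, 2.0, 3.0])  # idealised PAM4 levels
ret = np.cumsum(levels, dtype=float)     # [0., 1., 3., 6.]
ret[2:] = ret[2:] - ret[:-2]             # [0., 1., 3., 5.]
print(ret[1:] / 2)                       # [0.5 1.5 2.5]
```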
def analyse_single(self, data, index):
warnings.filterwarnings("error")
eye_stats = {}
@@ -238,12 +246,15 @@ class eye_diagram:
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2
eye_stats["time_midpoint"] = 1.0
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
data, approx_levels, time_bounds
)
eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
eye_stats["heights"] = eye_diagram.calculate_eye_heights(
@@ -260,22 +271,23 @@ class eye_diagram:
# if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
# raise ValueError
eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
eye_stats["mean_area"] = np.mean(eye_stats["areas"])
eye_stats["min_area"] = np.min(eye_stats["areas"])
# eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
# eye_stats["mean_area"] = np.mean(eye_stats["areas"])
# eye_stats["min_area"] = np.min(eye_stats["areas"])
eye_stats["success"] = True
except (RuntimeWarning, UserWarning, ValueError):
eye_stats["success"] = False
eye_stats["time_midpoint"] = 0
eye_stats["levels"] = np.zeros(self.n_levels)
eye_stats["amplitude_clusters"] = []
eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
eye_stats["heights"] = np.zeros(self.n_levels - 1)
eye_stats["widths"] = np.zeros(self.n_levels - 1)
eye_stats["areas"] = np.zeros(self.n_levels - 1)
eye_stats["mean_area"] = 0
eye_stats["min_area"] = 0
eye_stats["time_midpoint"] = None
eye_stats["levels"] = None
eye_stats["thresholds"] = None
eye_stats["amplitude_clusters"] = None
eye_stats["amplitudes"] = None
eye_stats["heights"] = None
eye_stats["widths"] = None
# eye_stats["areas"] = np.zeros(self.n_levels - 1)
# eye_stats["mean_area"] = 0
# eye_stats["min_area"] = 0
warnings.resetwarnings()
return eye_stats
@@ -441,7 +453,8 @@ if __name__ == "__main__":
data = generate_sample_data(length, noise=0.005)
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
eye.analyse()
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
for i, channel in enumerate(eye.eye_stats):
print(f"Channel {i}")
print_data = {attr: channel[attr] for attr in attrs}

View File

@@ -1,6 +1,9 @@
import matplotlib.pyplot as plt
import numpy as np
from .datasets import load_data
if __name__ == "__main__":
from datasets import load_data
else:
from .datasets import load_data
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
"""Plot an eye diagram for the data given by filepath.
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
raise ValueError("Either path or data and sps must be given.")
if path is not None:
data, config = load_data(path, skipfirst, symbols)
data = data.detach().cpu().numpy()[:, :4]
sps = int(config["glova"]["sps"])
if sps is None:
raise ValueError("sps not set.")
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
plt.show()
return fig
if __name__ == "__main__":
eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)