model robustness testing
.gitignore (vendored, 1 line changed)
@@ -163,3 +163,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
tolerance_results/datasets/*
notes/models.md (new file, 37 lines)
@@ -0,0 +1,37 @@
# models

## no polarisation flipping

```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241230_011907.tar"
```

```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241230_103752.tar"
```

```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241230_164534.tar"
```

## with polarisation flipping

Polarisation flipping: the signal is randomly rotated by 180°. A polarisation rotation can be detected by adding a tone on one of the polarisations, but with a direct-detection setup only modulo 180°. The randomly flipped signal should allow the network to learn to compensate dispersion and PMD independently of the polarisation rotation. The training data includes the flipped signal as well, but no indication of whether the polarisation is flipped.
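
For illustration only (not part of this commit): a minimal sketch of such a flip augmentation; the tensor layout and helper name are assumptions. A 180° polarisation rotation multiplies the Jones vector by -1, which leaves the directly detected power unchanged:

```py
# illustrative sketch, not from this repo: random 180-degree polarisation flip
import torch

def random_pol_flip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # x: complex field, assumed shape (batch, 2, samples) for the two polarisations
    # a 180-degree rotation maps (Ex, Ey) -> (-Ex, -Ey), so direct detection
    # (|E|^2) cannot distinguish flipped from unflipped
    flip = torch.rand(x.shape[0], 1, 1, device=x.device) < p
    return torch.where(flip, -x, x)
```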

```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241231_000328.tar"
```

```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241231_163614.tar"
```

```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241231_170532.tar"
```
@@ -124,7 +124,7 @@ class regenerator(Module):
parametrizations: list[dict] = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
prescale=1,
rotate=False,
):
super(regenerator, self).__init__()
@@ -134,15 +134,14 @@ class regenerator(Module):
act_func_kwargs = act_func_kwargs or {}

self.rotation = rotate
self.prescale = prescale

self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)

def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())

if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))

module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
@@ -156,8 +155,8 @@ class regenerator(Module):

self.add_module(f"layer_{self._n_hidden_layers}", Sequential())

if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
# if scale_layers:
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))

module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
@@ -200,6 +199,7 @@ class regenerator(Module):
return powers

def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
x = x * self.prescale
powers = self._trace_powers(trace_powers, x)
# x = self.layer_0(x)
# powers = self._trace_powers(trace_powers, x, powers)

@@ -683,7 +683,7 @@ class RegenerationTrainer:

def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
model_kwargs = None
else:
model_kwargs = model_kwargs

@@ -692,6 +692,14 @@ class RegenerationTrainer:

input_dim = 2 * self.data_settings.output_size

# if self.data_settings.polarisations:
# input_dim *= 2

output_dim = self.model_settings.output_dim

# if self.data_settings.polarisations:
output_dim *= 2

dtype = getattr(torch, self.data_settings.dtype)

afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
@@ -703,7 +711,7 @@ class RegenerationTrainer:
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]

self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"dims": (input_dim, *hidden_dims, output_dim),
"layer_function": layer_func,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
@@ -711,7 +719,7 @@ class RegenerationTrainer:
"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": self.model_settings.scale,
"prescale": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
@@ -745,11 +753,12 @@ class RegenerationTrainer:
num_symbols = None
config_path = self.data_settings.config_path
randomise_polarisations = self.data_settings.randomise_polarisations
polarisations = self.data_settings.polarisations
osnr = self.data_settings.osnr
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
# polarisations = override.get("polarisations", polarisations)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# get dataset
dataset = FiberRegenerationDataset(
@@ -763,6 +772,7 @@ class RegenerationTrainer:
real=not dtype.is_complex,
num_symbols=num_symbols,
randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
osnr = osnr,
)

@@ -832,17 +842,19 @@ class RegenerationTrainer:
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
angles = batch["mean_angle"]
x = batch[x_key]
y = batch[y_key]
angle = batch["mean_angle"]
self.model.zero_grad(set_to_none=True)
x, y, angles = (
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angles.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x, -angles)
y_pred = self.model(x, -angle)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
@@ -886,17 +898,19 @@ class RegenerationTrainer:

self.model.eval()
running_error = 0
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
angles = batch["mean_angle"]
x, y, angles = (
x = batch[x_key]
y = batch[y_key]
angle = batch["mean_angle"]
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angles.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x, -angles)
y_pred = self.model(x, -angle)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
@@ -953,15 +967,17 @@ class RegenerationTrainer:
regen = []
timestamps = []
angles = []
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
y_key = "y_stacked"# if self.data_settings.polarisations else "y"

with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for batch in loader:
x = batch["x"]
y = batch["y"]
x = batch[x_key]
y = batch[y_key]
plot_target = batch["plot_target"]
angle = batch["mean_angle"]
center_angle = batch["center_angle"]
# center_angle = batch["center_angle"]
timestamp = batch["timestamp"]
plot_data = batch["plot_data"]
plot_data_rot = batch["plot_data_rot"]
@@ -971,14 +987,16 @@ class RegenerationTrainer:
angle.to(self.pytorch_settings.device),
)
if trace_powers:
y_pred, powers = model(x, angle, True).cpu()
y_pred, powers = model(x, -angle, True).cpu()
else:
y_pred = model(x, angle).cpu()
y_pred = model(x, -angle).cpu()
# x = x.cpu()
# y = y.cpu()
# if self.data_settings.polarisations:
y_pred = y_pred[:, :2]
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y_pred = y_pred[:, y_pred.shape[1]//2, :]
y = y.view(y.shape[0], -1, 2)
# y = y.view(y.shape[0], -1, 2)
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# c = torch.cos(-angle).cpu()
# s = torch.sin(-angle).cpu()
@@ -996,7 +1014,7 @@ class RegenerationTrainer:
fiber_in.append(plot_target.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
angles.append(center_angle.squeeze())
angles.append(angle.squeeze())

fiber_out = torch.vstack(fiber_out).cpu()
fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
@@ -1352,7 +1370,8 @@ class RegenerationTrainer:
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
"shuffle": False,
"polarisations": (np.random.rand(1) * np.pi * 2,),
# "polarisations": (np.random.rand(1) * np.pi * 2,),
"polarisations": self.data_settings.polarisations,
"randomise_polarisation": self.data_settings.randomise_polarisations,
}
)
@@ -1366,7 +1385,7 @@ class RegenerationTrainer:
fiber_out_rot = fiber_out_rot.view(-1, 2)
angles = angles.view(-1, 1)
angles = angles.real
angles = torch.fmod(angles, 2 * torch.pi)
angles = torch.fmod(angles, 2*torch.pi)
angles = torch.div(angles, 2*torch.pi)
angles = torch.repeat_interleave(angles, 2, dim=1)
src/single-core-regen/plot_model.py (new file, 253 lines)
@@ -0,0 +1,253 @@
import os
from matplotlib import pyplot as plt
import numpy as np
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models

# def move_to_location_in_size(array, location, size):
# array_x, array_y = array.shape
# location_x, location_y = location
# size_x, size_y = size

# left = location_x
# right = size_x - array_x - location_x

# top = location_y
# bottom = size_y - array_y - location_y

# return np.pad(
# array,
# (
# (left, right),
# (top, bottom),
# ),
# constant_values=(-np.inf, -np.inf),
# )


def pad_to_size(array, size):
if not hasattr(size, "__len__"):
size = (size, size)

left = (
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
)
right = (
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
)
top = (
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
)
bottom = (
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
)

array: np.ndarray = array
if array.ndim == 2:
return np.pad(
array,
(
(left, right),
(top, bottom),
),
constant_values=(np.nan, np.nan),
)
elif array.ndim == 3:
return np.pad(
array,
(
(left, right),
(top, bottom),
(0,0)
),
constant_values=(np.nan, np.nan),
)

def model_plot(model_path):
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
checkpoint_dict = torch.load(model_path, weights_only=True)

dims = checkpoint_dict["model_kwargs"].pop("dims")

model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
model.load_state_dict(checkpoint_dict["model_state_dict"])

model_params = []
plots = []
max_size = np.max(dims)
# max_act_size = np.max(dims[1:])

angles = [None, None]
weights = [None, None]

for num, (layer_name, layer) in enumerate(model.named_children()):
# each layer contains an "ONN" layer and an "activation" layer
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
onn_weights = layer.ONN.weight.T
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real


act = layer.activation

act_values = np.ones((act.size, 1))

act_values = np.nan * act_values

act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy()
...
# act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8)
# act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8)
# xs = (0.01, 0.1, 1)

# act_values = np.zeros((act.size, len(xs)*2))
# act_angles = np.zeros((act.size, len(xs)*2))

# act_values[:,:] = np.nan
# act_angles[:,:] = np.nan

# for xi, x in enumerate(xs):
# phi_intermediate = act_phi_gain * x**2 + act_phi_bias

# act_resulting_gain = (
# 1j
# * torch.sqrt(1-act.alpha)
# * torch.exp(-0.5j * phi_intermediate)
# * torch.cos(0.5 * phi_intermediate)
# * x
# )

# act_resulting_gain = act_resulting_gain.detach().cpu().numpy()
# act_values[:, xi*2] = np.abs(act_resulting_gain).real
# act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real


# if angles[0] is None or angles[0] > np.min(onn_angles.flatten()):
# angles[0] = np.min(onn_angles.flatten())
# if angles[1] is None or angles[1] < np.max(onn_angles.flatten()):
# angles[1] = np.max(onn_angles.flatten())
# if weights[0] is None or weights[0] > np.min(onn_weights.flatten()):
# weights[0] = np.min(onn_weights.flatten())
# if weights[1] is None or weights[1] < np.max(onn_weights.flatten()):
# weights[1] = np.max(onn_weights.flatten())

model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)})

# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))

for plot in plots:
layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem()
# for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0])
# for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0])

onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values))
onn_values = onn_values - np.min(onn_values)
onn_values = onn_values / np.max(onn_values)

act_values = np.ma.array(act_values, mask=np.isnan(act_values))
act_values = act_values - np.min(act_values)
act_values = act_values / np.max(act_values)


onn_values = onn_values
onn_values = pad_to_size(onn_values, (max_size, None))

act_values = act_values
act_values = pad_to_size(act_values, (max_size, 3))

onn_angles = onn_angles / np.pi
onn_angles = pad_to_size(onn_angles, (max_size, None))

act_angles = act_angles / np.pi
act_angles = pad_to_size(act_angles, (max_size, 3))


# onn_angles = onn_angles - np.min(onn_angles)
# onn_angles = onn_angles / np.max(onn_angles)

# act_angles = act_angles - np.min(act_angles)
# act_angles = act_angles / np.max(act_angles)

if num == 0:
value_img = np.concatenate((onn_values, act_values), axis=1)
angle_img = np.concatenate((onn_angles, act_angles), axis=1)
else:
value_img = np.concatenate((value_img, onn_values, act_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1)


# -np.inf to np.nan
# value_img[value_img == -np.inf] = np.nan

# angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size)
# angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size)


from cmcrameri import cm
from matplotlib import colors as mcolors
alpha_map = mcolors.LinearSegmentedColormap(
'alphamap',
{
'red': [(0, 0, 0), (1, 0, 0)],
'green': [(0, 0, 0), (1, 0, 0)],
'blue': [(0, 0, 0), (1, 0, 0)],
'alpha': [
(0, 1, 1),
# (0.2, 0.2, 0.1),
(1, 0, 0)
]
}
)
alpha_map.set_bad(color="#AAAAAA")


fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5))
fig.tight_layout()
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
masked_value_img = value_img
cmap = cm.batlowW
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap)

masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
cmap = cm.romaO
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap)
im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)

axs[0].axis("off")
axs[1].axis("off")
axs[2].axis("off")

axs[0].set_title("Values")
axs[1].set_title("Angles")
axs[2].set_title("Values and Angles")


...
plt.show()
# model = models.regenerator(*dims, **model_kwargs)


if __name__ == "__main__":
model_plot(".models/best_20250105_145719.tar")
@@ -1,4 +1,4 @@
from datetime import datetime
# from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
@@ -13,7 +13,7 @@ from hypertraining.settings import (
OptimizerSettings,
)

from hypertraining.training import RegenerationTrainer, PolarizationTrainer
from hypertraining.training import RegenerationTrainer#, PolarizationTrainer

# import torch
import json
@@ -27,7 +27,7 @@ global_settings = GlobalSettings(

data_settings = DataSettings(
# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",
config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -37,12 +37,13 @@ data_settings = DataSettings(
shuffle=True,
drop_first=64,
train_split=0.8,
randomise_polarisations=True,
osnr=10,
randomise_polarisations=False,
polarisations=True,
osnr=16, #16dB due to amplification with NF 5
)

pytorch_settings = PytorchSettings(
epochs=10000,
epochs=1000,
batchsize=2**14,
device="cuda",
dataloader_workers=24,
@@ -64,11 +65,11 @@ model_settings = ModelSettings(
# "n_hidden_nodes_3": 4,
# "n_hidden_nodes_4": 2,
},
model_activation_func="EOActivation",
model_activation_func="phase_shift",
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=False,
scale=2.0,
model_layer_parametrizations=[
{
"tensor_name": "weight",
@@ -77,13 +78,17 @@ model_settings = ModelSettings(
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 1,
},
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
"max": None,
},
},
{
@@ -95,8 +100,12 @@ model_settings = ModelSettings(
},
},
{
"tensor_name": "scales",
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2,
},
},
{
"tensor_name": "angle",
@@ -244,9 +253,17 @@ if __name__ == "__main__":
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
checkpoint_path=".models/best_20241216_221359.tar",
# checkpoint_path=".models/best_20250104_191428.tar",
reset_epoch=True,
# settings_override={
# "data_settings": {
# "config_path": "data/20241229-163*-128-16384-100000-*.ini",
# "polarisations": True,
# },
# "model_settings": {
# "scale": 2.0,
# }
# }
# "optimizer_settings": {
# "optimizer_kwargs": {
# "lr": 0.01,
@@ -16,16 +16,17 @@ from datetime import datetime
import hashlib
from pathlib import Path
import time
import h5py
from matplotlib import pyplot as plt # noqa: F401
import numpy as np

import add_pypho # noqa: F401
from . import add_pypho # noqa: F401
import pypho

default_config = f"""
[glova]
nos = 256
sps = 256
sps = 128
nos = 16384
f0 = 193414489032258.06
symbolrate = 10e9
wisdom_dir = "{str((Path.home() / ".pypho"))}"
@@ -37,9 +38,9 @@ length = 10000
gamma = 1.14
alpha = 0.2
D = 17
S = 0
birefsteps = 0
max_delta_beta = 0.4
S = 0.058
bireflength = 10
max_delta_beta = 0.14
; birefseed = 0xC0FFEE

[signal]
@@ -47,17 +48,15 @@ max_delta_beta = 0.4

modulation = "pam"
mod_order = 4
mod_depth = 0.8

mod_depth = 1
max_jitter = 0.02
; jitter_seed = 0xC0FFEE

laser_power = 0
edfa_power = 3
edfa_power = 0
edfa_nf = 5

pulse_shape = "gauss"
fwhm = 0.33
osnr = "inf"

[data]
dir = "data"
@@ -71,6 +70,7 @@ def get_config(config_file=None):
"""
if config_file is None:
config_file = Path(__file__).parent / "signal_generation.ini"
config_file = Path(config_file)
if not config_file.exists():
with open(config_file, "w") as f:
f.write(default_config)
@@ -83,7 +83,10 @@ def get_config(config_file=None):
conf[section] = {}
for key in config[section]:
# print(f"{key} = {config[section][key]}")
conf[section][key] = eval(config[section][key])
try:
conf[section][key] = eval(config[section][key])
except NameError:
conf[section][key] = float(config[section][key])
# if isinstance(conf[section][key], str):
# conf[section][key] = config[section][key].strip('"')
return conf
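
As an aside (illustration, not part of the diff): the NameError fallback above is presumably needed because the saved configs write some values unquoted (e.g. `osnr = inf`), and `eval()` cannot parse those bare tokens while `float()` can. A quick sketch:

```py
# illustration only: why eval() needs the NameError fallback in get_config
eval("10e9")     # -> 100000000000.0  (numbers parse fine)
eval('"gauss"')  # -> 'gauss'         (quoted strings parse fine)
# eval("inf")    # -> NameError: name 'inf' is not defined
float("inf")     # -> inf             (the fallback handles bare inf/nan)
```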
@@ -96,7 +99,9 @@ class PDM_IM_IPM:
mod_order=8,
seed=None,
):
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1"
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
"mod_order must be a cube of an integer greater than 1"
)
self.glova = glova
self.mod_order = mod_order
self.symbols_per_dim = int(np.cbrt(mod_order))
@@ -106,18 +111,11 @@ class PDM_IM_IPM:
rs = np.random.RandomState(self.seed)
symbols = rs.randint(0, self.mod_order, n)
return symbols


class pam_generator:
def __init__(
self,
glova,
mod_order=None,
mod_depth=0.5,
pulse_shape="gauss",
fwhm=0.33,
seed=None,
single_channel=False
self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
) -> None:
self.glova = glova
self.pulse_shape = pulse_shape
@@ -133,41 +131,36 @@ class pam_generator:
wavelet = self.gauss(oversampling=6)
else:
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")


# prepare symbols
symbols_x = symbols[0] / (self.mod_order)
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
digital_x = np.pad(
digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)

digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))

# create analog signal of diff of symbols
E_x = np.convolve(digital_x, wavelet)


# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)


# cut off the wavelet tails
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]


# modulate the laser
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))


if not self.single_channel:
symbols_y = symbols[1] / (self.mod_order)
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
digital_y = np.pad(
digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
E_y = np.convolve(digital_y, wavelet)

E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)

E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]


E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))

# rotate the signal on the y-polarisation by 90°
@@ -175,7 +168,6 @@ class pam_generator:
else:
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
return E


def generate_digital_signal(self, symbols, max_jitter=0):
rs = np.random.RandomState(self.seed)
@@ -198,15 +190,11 @@ class pam_generator:
endpoint=True,
)
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
pulse = (
1
/ (sigma * np.sqrt(2 * np.pi))
* np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
)
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
return pulse


def initialize_fiber_and_data(config, input_data_override=None):
def initialize_fiber_and_data(config):
py_glova = pypho.setup(
nos=config["glova"]["nos"],
sps=config["glova"]["sps"],
@@ -221,48 +209,54 @@ def initialize_fiber_and_data(config, input_data_override=None):
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])

if input_data_override is not None:
c_data.E_in = input_data_override[0]
noise = input_data_override[1]
else:
config["signal"]["seed"] = config["signal"].get(
"seed", (int(time.time() * 1000)) % 2**32
)
config["signal"]["jitter_seed"] = config["signal"].get(
"jitter_seed", (int(time.time() * 1000)) % 2**32
)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"],
fwhm=config["signal"]["fwhm"],
seed=config["signal"]["jitter_seed"],
mod_order=config["signal"]["mod_order"],
osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))

config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"],
fwhm=config["signal"]["fwhm"],
seed=config["signal"]["jitter_seed"],
mod_order=config["signal"]["mod_order"],
)

symbols_x = symbolsrc(pattern="random")
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0
symbols_y[:3] = 0
# symbols_x += 1

cw = laser()

source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))

if osnr != float("inf"):
osnr_lin = 10 ** (osnr / 10)
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
noise_power = signal_power / osnr_lin
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
0, 1, source_signal[0]["E"].shape
)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
noise = noise * np.sqrt(noise_power / noise_power_is)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
source_signal[0]["E"] += noise
source_signal[0]["noise"] = noise_power_is

symbols_x = symbolsrc(pattern="random")
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0
symbols_y[:3] = 0
# symbols_x += 1
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]

nf = py_edfa.NF
source_signal = py_edfa(E=source_signal, NF=0)
py_edfa.NF = nf

cw = laser()

source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))

# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]

source_signal = py_edfa(E=source_signal)

c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]

py_fiber = pypho.fiber(
glova=py_glova,
@@ -273,25 +267,21 @@ def initialize_fiber_and_data(config, input_data_override=None):
S=config["fiber"]["s"],
)
if config["fiber"].get("birefsteps", 0) > 0:
seed = config["fiber"].get(
"birefseed", (int(time.time() * 1000)) % 2**32
)
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l,
py_fiber.l / config["fiber"]["birefsteps"],
# maxDeltaD=config["fiber"]["d"]/5,
maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
seed=seed,
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(
py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)

return c_fiber, c_data, noise, py_edfa
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y)


def save_data(data, config):
def save_data(data, config, **metadata):
data_dir = Path(config["data"]["dir"])
npy_dir = config["data"].get("npy_dir", "")
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
@@ -306,6 +296,7 @@ def save_data(data, config):
seed = config["signal"].get("seed", False)
jitter_seed = config["signal"].get("jitter_seed", False)
birefseed = config["fiber"].get("birefseed", False)
osnr = float(config["signal"].get("osnr", "inf"))

config_content = "\n".join((
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
@@ -317,14 +308,14 @@ def save_data(data, config):
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
f'flags = "{config["glova"]["flags"]}"',
f"nthreads = {config['glova']['nthreads']}",
" ",
"",
"[fiber]",
f"length = {config['fiber']['length']}",
f"gamma = {config['fiber']['gamma']}",
f"alpha = {config['fiber']['alpha']}",
f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps',0)}",
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
"",
@@ -334,75 +325,84 @@ def save_data(data, config):
f'modulation = "{config["signal"]["modulation"]}"',
f"mod_order = {config['signal']['mod_order']}",
f"mod_depth = {config['signal']['mod_depth']}",
""
"",
f"max_jitter = {config['signal']['max_jitter']}",
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
""
"",
f"laser_power = {config['signal']['laser_power']}",
f"edfa_power = {config['signal']['edfa_power']}",
f"edfa_nf = {config['signal']['edfa_nf']}",
""
f"osnr = {osnr}",
"",
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
f"fwhm = {config['signal']['fwhm']}",
"",
"[data]",
f'dir = "{str(data_dir)}"',
f'npy_dir = "{npy_dir}"',
"file = "
"file = ",
))
config_hash = hashlib.md5(config_content.encode()).hexdigest()
save_file = f"{config_hash}.npy"
save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n'

filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config['fiber'].get('birefsteps',0),
config["fiber"].get("birefsteps", 0),
config["fiber"].get("max_delta_beta", 0),
int(config["glova"]["symbolrate"] / 1e9),
)

lookup_file = "-".join(map(str, filename_components)) + ".ini"
with open(data_dir / lookup_file, "w") as f:
config_filename = data_dir / lookup_file
with open(config_filename, "w") as f:
f.write(config_content)

np.save(save_dir / save_file, save_data)
with h5py.File(save_dir / save_file, "w") as outfile:
outfile.create_dataset("data", data=save_data)
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
for key, value in metadata.items():
# if isinstance(value, dict):
# value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data)

print("Saved config to", data_dir / lookup_file)
print("Saved config to", config_filename)
print("Saved data to", save_dir / save_file)

return config_filename


def length_loop(config, lengths, save=True):
lengths = sorted(lengths)
for length in lengths:
print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length

cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
config["fiber"]["length"] = length

cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)

mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))

print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")

print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)


E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
cdata.E_out = E_tmp[0]["E"]

mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")

if save:
save_data(cdata, config)
@@ -411,27 +411,57 @@ def length_loop(config, lengths, save=True):


def single_run_with_plot(config, save=True):
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
cfiber, cdata, config_filename = single_run(config, save)

mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename

def single_run(config, save=True):
cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)

# mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
# print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")

# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_in / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")

cfiber()

mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
# mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")

E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
# noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)

# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_out / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")

E_tmp = [{"E": cdata.E_out, "noise": noise}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
if save:
save_data(cdata, config)
cdata.E_out = E_tmp[0]["E"]
# noise = E_tmp[0]["noise"]

# mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))

# print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")

# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_amp / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")

config_filename = None
symbols = np.array(symbols)
if save:
config_filename = save_data(cdata, config, **{"symbols": symbols})
return cfiber,cdata,config_filename

in_out_eyes(cfiber, cdata, show_pols=False)

def in_out_eyes(cfiber, cdata, show_pols=False):
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
@@ -595,9 +625,7 @@ def plot_eye_diagram(
signal = signal[: head * eye_width]
if normalize:
signal = signal / np.max(signal)
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
offset % (eye_width + 1) :: eye_width
]
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
for slice in slices:
ax.plot(plt_ax, slice, color=color, alpha=0.1)
@@ -617,15 +645,27 @@ if __name__ == "__main__":
# lengths.append(10*max(ranges))
# lengths = [*lengths, *lengths]
lengths = (
# 8000, 9000,
10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000,
95000, 100000, 105000, 110000, 115000, 120000
# 8000, 9000,
10000,
20000,
30000,
40000,
50000,
60000,
70000,
80000,
90000,
95000,
100000,
105000,
110000,
115000,
120000,
)

# lengths = (10000,100000)

length_loop(config, lengths, save=True)
# length_loop(config, lengths, save=True)
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)

# single_run_with_plot(config, save=False)

single_run_with_plot(config, save=False)
src/single-core-regen/tolerance_testing.py (new file, 723 lines)
@@ -0,0 +1,723 @@
"""
Tests a given model for tolerance against variations in
- fiber length
- baudrate
- OSNR

CD, PMD and baudrate need different datasets; OSNR is modeled as AWGN added to the data before feeding it into the model.
"""

from datetime import datetime
from typing import Literal
from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path

import h5py
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models

from signal_gen.generate_signal import single_run, get_config

import json


class NestedParameterIterator:
def __init__(self, parameters):
"""
parameters: dict with key <param_name> and value <dict with keys "config" and "range">
"""
# self.parameters = parameters
self.names = []
self.ranges = []
self.configs = []
for k, v in parameters.items():
self.names.append(k)
self.ranges.append(v["range"])
self.configs.append(v["config"])
self.n_parameters = len(self.ranges)

self.idx = 0

self.range_idx = [0] * self.n_parameters
self.range_len = [len(r) for r in self.ranges]
self.length = int(np.prod(self.range_len))
self.out = []

for i in range(self.length):
self.out.append([])
for j in range(self.n_parameters):
element = {self.names[j]: {"value": self.ranges[j][self.range_idx[j]], "config": self.configs[j]}}
self.out[i].append(element)
self.range_idx[-1] += 1

# update range_idx back to front
for j in range(self.n_parameters - 1, -1, -1):
if self.range_idx[j] == self.range_len[j]:
self.range_idx[j] = 0
self.range_idx[j - 1] += 1

...

def __next__(self):
if self.idx == self.length:
raise StopIteration
self.idx += 1
return self.out[self.idx - 1]

def __iter__(self):
return self
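
(Illustration, not part of the file: the iterator yields the cartesian product of the registered ranges, one list of {name: {"value": ..., "config": ...}} entries per combination, with the last parameter varying fastest. The values below are made up.)

```py
# illustrative usage of NestedParameterIterator, with made-up values
params = {
    "length": {"config": ("fiber", "length"), "range": [10e3, 50e3]},
    "osnr":   {"config": ("signal", "osnr"),  "range": [12, 16]},
}
for combo in NestedParameterIterator(params):
    print(combo)
# 4 combinations, last parameter fastest:
# (10e3, 12), (10e3, 16), (50e3, 12), (50e3, 16)
```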


class model_runner:
def __init__(
self,
# length_range: tuple[int | float] = (50e3, 50e3),
# length_steps: int = 1,
# length_log: bool = False,
# baudrate_range: tuple[int | float] = (10e9, 10e9),
# baudrate_steps: int = 1,
# baudrate_log: bool = False,
# osnr_range: tuple[int | float] = (16, 16),
# osnr_steps: int = 1,
# osnr_log: bool = False,
# dataset_dir: str = "data",
# dataset_datetime_glob: str = "*",
results_dir: str = "tolerance_results/datasets",
# model_dir: str = ".models",
config: str = "signal_generation.ini",
config_dir: str = None,
debug: bool = False,
):
"""
length_range: lower and upper limit of length, in meters
length_step: step size of length, in meters
baudrate_range: lower and upper limit of baudrate, in Bd
baudrate_step: step size of baudrate, in Bd
osnr_range: lower and upper limit of osnr, in dB
osnr_step: step size of osnr, in dB
dataset_dir: directory containing datasets
dataset_datetime_glob: datetime glob pattern for dataset files
results_dir: directory to save results
model_dir: directory containing models
"""
|
||||
self.debug = debug
|
||||
|
||||
self.parameters = {}
|
||||
self.iter = None
|
||||
|
||||
# self.update_length_range(length_range, length_steps, length_log)
|
||||
# self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log)
|
||||
# self.update_osnr_range(osnr_range, osnr_steps, osnr_log)
|
||||
|
||||
# self.data_dir = Path(dataset_dir)
|
||||
# self.data_datetime_glob = dataset_datetime_glob
|
||||
self.results_dir = Path(results_dir)
|
||||
# self.model_dir = Path(model_dir)
|
||||
|
||||
config_dir = config_dir or Path(__file__).parent
|
||||
self.config = config_dir / config
|
||||
|
||||
torch.serialization.add_safe_globals([
|
||||
*util.complexNN.__all__,
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
ModelSettings,
|
||||
OptimizerSettings,
|
||||
PytorchSettings,
|
||||
models.regenerator,
|
||||
torch.nn.utils.parametrizations.orthogonal,
|
||||
])
|
||||
|
||||
self.load_model()
|
||||
|
||||
self.datasets = []
|
||||
|
||||
# def register_parameter(self, name, config):
|
||||
# self.parameters.append({"name": name, "config": config})
|
||||
|
||||
def load_results_from_file(self, path):
|
||||
data, meta = self.load_from_file(path)
|
||||
self.results = [d.decode() for d in data]
|
||||
self.parameters = meta["parameters"]
|
||||
...
|
||||
|
||||
def load_datasets_from_file(self, path):
|
||||
data, meta = self.load_from_file(path)
|
||||
self.datasets = [d.decode() for d in data]
|
||||
self.parameters = meta["parameters"]
|
||||
...
|
||||
|
||||
def update_parameter_range(self, name, config, range, steps, log):
|
||||
self.parameters[name] = {"config": config, "range": self.update_range(*range, steps, log)}
|
||||
|
||||
def generate_iterations(self):
|
||||
if len(self.parameters) == 0:
|
||||
raise ValueError("No parameters registered")
|
||||
self.iter = NestedParameterIterator(self.parameters)
|
||||
|
||||
def generate_datasets(self):
|
||||
# get base config
|
||||
config = get_config(self.config)
|
||||
|
||||
if self.iter is None:
|
||||
self.generate_iterations()
|
||||
|
||||
for params in self.iter:
|
||||
current_settings = []
|
||||
# params is a list of dictionaries with keys "name", containing a dict with keys "value", "config"
|
||||
for param in params:
|
||||
for name, settings in param.items():
|
||||
current_settings.append({name: settings["value"]})
|
||||
self.nested_set(config, settings["config"], settings["value"])
|
||||
settings_strs = []
|
||||
for setting in current_settings:
|
||||
name = list(setting)[0]
|
||||
settings_strs.append(f"{name}: {float(setting[name]):.2e}")
|
||||
settings_str = ", ".join(settings_strs)
|
||||
print(f"Generating dataset for [{settings_str}]")
|
||||
# TODO look for existing datasets
|
||||
_, _, path = single_run(config)
|
||||
self.datasets.append(str(path))
|
||||
|
||||
datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back")
|
||||
metadata = {"parameters": self.parameters}
|
||||
data = np.array(self.datasets, dtype="S")
|
||||
self.save_to_file(datasets_list_path, data, **metadata)
|
||||
|
||||
@staticmethod
|
||||
def nested_set(dic, keys, value):
|
||||
for key in keys[:-1]:
|
||||
dic = dic.setdefault(key, {})
|
||||
dic[keys[-1]] = value
|
||||
|
||||
## Dataset and model loading
|
||||
# def find_datasets(self, data_dir=None, data_datetime_glob=None):
|
||||
# # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini
|
||||
# data_dir = data_dir or self.data_dir
|
||||
# data_datetime_glob = data_datetime_glob or self.data_datetime_glob
|
||||
# self.datasets = {}
|
||||
# data_dir = Path(data_dir)
|
||||
# for length in self.lengths:
|
||||
# for baudrate in self.baudrates:
|
||||
# # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini"
|
||||
# dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini"
|
||||
# datasets = [f for f in data_dir.glob(dataset_glob)]
|
||||
# if len(datasets) == 0:
|
||||
# continue
|
||||
# self.datasets[length] = {}
|
||||
# if len(datasets) > 1:
|
||||
# print(
|
||||
# f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset."
|
||||
# )
|
||||
# # get newest file from creation date
|
||||
# datasets.sort(key=lambda x: x.stat().st_ctime)
|
||||
# self.datasets[length][baudrate] = str(datasets[-1])
|
||||
|
||||
def load_dataset(self, dataset_path):
|
||||
if self.checkpoint_dict is None:
|
||||
raise ValueError("Model must be loaded before dataset")
|
||||
|
||||
if self.dataset_path is None:
|
||||
self.dataset_path = dataset_path
|
||||
elif self.dataset_path == dataset_path:
|
||||
return
|
||||
|
||||
symbols = self.checkpoint_dict["settings"]["data_settings"].symbols
|
||||
data_size = self.checkpoint_dict["settings"]["data_settings"].output_size
|
||||
dtype = getattr(torch, self.checkpoint_dict["settings"]["data_settings"].dtype)
|
||||
drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first
|
||||
randomise_polarisations = self.checkpoint_dict["settings"]["data_settings"].randomise_polarisations
|
||||
polarisations = self.checkpoint_dict["settings"]["data_settings"].polarisations
|
||||
num_symbols = None
|
||||
if self.debug:
|
||||
num_symbols = 1000
|
||||
|
||||
config_path = Path(dataset_path)
|
||||
|
||||
dataset = util.datasets.FiberRegenerationDataset(
|
||||
file_path=config_path,
|
||||
symbols=symbols,
|
||||
output_dim=data_size,
|
||||
drop_first=drop_first,
|
||||
dtype=dtype,
|
||||
real=not dtype.is_complex,
|
||||
randomise_polarisations=randomise_polarisations,
|
||||
polarisations=polarisations,
|
||||
num_symbols=num_symbols,
|
||||
# device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
)
|
||||
|
||||
self.dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=2**14, pin_memory=True, num_workers=24, prefetch_factor=8, shuffle=False
|
||||
)
|
||||
|
||||
return self.dataloader.dataset.orig_symbols
|
||||
# run model
|
||||
# return results as array: [fiber_in, fiber_out, fiber_out_noisy, regen_out]
|
||||
|
||||
def load_model(self, model_path: str | None = None):
|
||||
if model_path is None:
|
||||
self.model = None
|
||||
self.model_path = None
|
||||
self.checkpoint_dict = None
|
||||
return
|
||||
|
||||
path = Path(model_path)
|
||||
|
||||
if self.model_path is None:
|
||||
self.model_path = path
|
||||
elif path == self.model_path:
|
||||
return
|
||||
|
||||
self.dataset_path = None # reset dataset path, as the shape depends on the model
|
||||
|
||||
self.checkpoint_dict = torch.load(path, weights_only=True)
|
||||
dims = self.checkpoint_dict["model_kwargs"].pop("dims")
|
||||
self.model = models.regenerator(*dims, **self.checkpoint_dict["model_kwargs"])
|
||||
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"])
|
||||
|
||||
## Model evaluation
|
||||
def run_model_evaluation(self, model_path: str, datasets: str | None = None):
|
||||
self.load_model(model_path)
|
||||
# iterate over datasets and osnr values:
|
||||
# load dataset, add noise, run model, return results
|
||||
# save results to file
|
||||
self.results = []
|
||||
|
||||
if datasets is not None:
|
||||
self.load_datasets_from_file(datasets)
|
||||
|
||||
n_datasets = len(self.datasets)
|
||||
for i, dataset_path in enumerate(self.datasets):
|
||||
conf = get_config(dataset_path)
|
||||
mpath = Path(model_path)
|
||||
model_base = mpath.stem
|
||||
print(f"({1+i}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}")
|
||||
|
||||
results_path = self.build_path(
|
||||
dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base
|
||||
)
|
||||
|
||||
orig_symbols = self.load_dataset(dataset_path)
|
||||
|
||||
data, loss = self.run_model()
|
||||
|
||||
metadata = {
|
||||
"model_path": model_path,
|
||||
"dataset_path": dataset_path,
|
||||
"loss": loss,
|
||||
"sps": conf["glova"]["sps"],
|
||||
"orig_symbols": orig_symbols
|
||||
# "config": conf,
|
||||
# "checkpoint_dict": self.checkpoint_dict,
|
||||
# "nos": self.dataloader.dataset.num_symbols,
|
||||
}
|
||||
|
||||
self.save_to_file(results_path, data, **metadata)
|
||||
self.results.append(str(results_path))
|
||||
|
||||
results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back")
|
||||
metadata = {"parameters": self.parameters}
|
||||
data = np.array(self.results, dtype="S")
|
||||
self.save_to_file(results_list_path, data, **metadata)

    def run_model(self):
        loss = 0
        datas = []

        self.model.eval()
        model = self.model.to("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            for batch in self.dataloader:
                x = batch["x_stacked"]
                y = batch["y_stacked"]
                fiber_in = batch["plot_target"]
                # fiber_out = batch["plot_clean"]
                fiber_out = batch["plot_data"]
                timestamp = batch["timestamp"]
                angle = batch["mean_angle"]
                x = x.to("cuda" if torch.cuda.is_available() else "cpu")
                angle = angle.to("cuda" if torch.cuda.is_available() else "cpu")
                regen = model(x, -angle)
                regen = regen.to("cpu")
                loss += util.complexNN.complex_mse_loss(regen, y, power=True).item()
                # shape: [batch_size, 4]
                plot_regen = regen[:, :2]
                plot_regen = plot_regen.view(plot_regen.shape[0], -1, 2)
                plot_regen = plot_regen[:, plot_regen.shape[1] // 2, :]

                data_out = torch.cat(
                    (
                        fiber_in,
                        fiber_out,
                        # fiber_out_noisy,
                        plot_regen,
                        timestamp.view(-1, 1),
                    ),
                    dim=1,
                )
                datas.append(data_out)

        data_out = torch.cat(datas, dim=0).numpy()

        return data_out, loss
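complex_mse_loss() is project-internal and not shown in this commit; this is only a guess at its shape. With power=True, a natural reading is an MSE on optical powers |E|² rather than on the complex fields themselves:

```py
# Hypothetical sketch of util.complexNN.complex_mse_loss - not the commit's code.
import torch

def complex_mse_loss(pred: torch.Tensor, target: torch.Tensor, power: bool = False) -> torch.Tensor:
    if power:
        # compare detected powers rather than complex field values
        return torch.mean((pred.abs().square() - target.abs().square()).square())
    return torch.mean((pred - target).abs().square())
```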

    ## File I/O
    @staticmethod
    def save_to_file(path: str | Path, data: np.ndarray, **metadata):
        # create directory if it doesn't exist
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        with h5py.File(path, "w") as outfile:
            outfile.create_dataset("data", data=data)
            for key, value in metadata.items():
                if isinstance(value, dict):
                    value = json.dumps(model_runner.convert_arrays(value))
                outfile.attrs[key] = value

    @staticmethod
    def convert_arrays(dict_in):
        """
        convert ndarrays in a (nested) dict to lists
        """
        dict_out = {}
        for key, value in dict_in.items():
            if isinstance(value, dict):
                dict_out[key] = model_runner.convert_arrays(value)
            elif isinstance(value, np.ndarray):
                dict_out[key] = value.tolist()
            else:
                dict_out[key] = value
        return dict_out

    @staticmethod
    def load_from_file(path: str):
        with h5py.File(path, "r") as infile:
            data = infile["data"][:]
            metadata = {}
            for key in infile.attrs.keys():
                if isinstance(infile.attrs[key], (str, bytes, bytearray)):
                    try:
                        metadata[key] = json.loads(infile.attrs[key])
                    except json.JSONDecodeError:
                        metadata[key] = infile.attrs[key]
                else:
                    metadata[key] = infile.attrs[key]
        return data, metadata
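Dict-valued metadata is serialised to JSON on write and parsed back on read, so ndarrays inside dicts come back as plain lists. A small round trip (the file name is a placeholder):

```py
import numpy as np

model_runner.save_to_file("demo.h5", np.zeros((4, 7)), loss=0.5, parameters={"OSNR": np.array([20.0, 40.0])})
data, meta = model_runner.load_from_file("demo.h5")
print(meta["parameters"]["OSNR"])  # [20.0, 40.0] - a list after the JSON round trip
```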

    ## Utility functions
    @staticmethod
    def logrange(start, stop, num, endpoint=False):
        lower, upper = np.log10((start, stop))
        return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10)
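logrange mirrors np.linspace but spaces its points geometrically, which is what the logarithmic baud-rate sweep in the main block relies on:

```py
import numpy as np

print(model_runner.logrange(1, 100, 3, endpoint=True))       # [  1.  10. 100.]
print(model_runner.logrange(10e9, 100e9, 3, endpoint=True))  # [1.00e+10 3.16e+10 1.00e+11]
```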

    @staticmethod
    def build_path(
        *elements, parent_dir: str | Path | None = None, filetype="h5", timestamp: Literal["no", "front", "back"] = "no"
    ):
        suffix = f".{filetype}" if not filetype.startswith(".") else filetype
        if timestamp != "no":
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            if timestamp == "front":
                elements = (ts, *elements)
            elif timestamp == "back":
                elements = (*elements, ts)
        path = "_".join(elements)
        path += suffix

        if parent_dir is not None:
            path = Path(parent_dir) / path
        return path
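Path elements are joined with underscores before the suffix is appended, which is exactly how the results list in the main block below gets its name:

```py
p = model_runner.build_path("results_list", parent_dir="tolerance_results/datasets", timestamp="back")
print(p)  # e.g. tolerance_results/datasets/results_list_20250110_232639.h5
```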

    @staticmethod
    def update_range(start, stop, n_steps, log):
        if log:
            values = model_runner.logrange(start, stop, n_steps, endpoint=True)
        else:
            values = np.linspace(start, stop, n_steps, endpoint=True)
        return values


class model_evaluation_result:
    def __init__(
        self,
        *,
        length=None,
        baudrate=None,
        osnr=None,
        model_path=None,
        dataset_path=None,
        loss=None,
        sps=None,
        **kwargs,
    ):
        self.length = length
        self.baudrate = baudrate
        self.osnr = osnr
        self.model_path = model_path
        self.dataset_path = dataset_path
        self.loss = loss
        self.sps = sps

        self.sers = None
        self.bers = None
        self.eye_stats = None


class evaluator:
    def __init__(self, datasets: list[str]):
        """
        datasets: iterable of dataset paths
        """
        self.datasets = datasets
        self.results = []

    def evaluate_datasets(self, plot=False):
        ## iterate over datasets
        # load dataset
        for dataset in self.datasets:
            model, dataset_name = dataset.split("/")[-2:]
            print(f"\nEvaluating model {model} with dataset {dataset_name}")
            data, metadata = model_runner.load_from_file(dataset)
            result = model_evaluation_result(**metadata)

            data = self.prepare_data(data, sps=metadata["sps"])

            try:
                sym_x, sym_y = metadata["orig_symbols"]
            except (TypeError, KeyError, ValueError):
                sym_x, sym_y = None, None

            self.evaluate_eye(data, result, title=dataset.split("/")[-1], plot=False)
            self.evaluate_ser_ber(data, result, sym_x, sym_y)
            print("BER:")
            self.print_dict(result.bers["regen"])
            print()
            print("SER:")
            self.print_dict(result.sers["regen"])
            print()

            self.results.append(result)
        if plot:
            plt.show()

    def evaluate_eye(self, data, result, title=None, plot=False):
        eye = util.eye_diagram.eye_diagram(
            data,
            channel_names=[
                "fiber_in_x",
                "fiber_in_y",
                # "fiber_out_x",
                # "fiber_out_y",
                "fiber_out_x",
                "fiber_out_y",
                "regen_x",
                "regen_y",
            ],
        )

        eye.analyse()
        eye.plot(title=title or "Eye diagram", show=plot)

        result.eye_stats = eye.eye_stats
        return eye.eye_stats

    def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None):
        if result.eye_stats is None:
            self.evaluate_eye(data, result)

        symbols = []
        sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
        bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}

        for channel_data, stats in zip(data, result.eye_stats):
            timestamps = channel_data[0]
            dat = channel_data[1]

            channel_name = stats["channel_name"]
            if stats["success"]:
                thresholds = stats["thresholds"]
                time_midpoint = stats["time_midpoint"]
            else:
                # fall back to the matching fiber_in channel, or to evenly spaced levels
                if channel_name.endswith("x"):
                    thresholds = result.eye_stats[0]["thresholds"]
                    time_midpoint = result.eye_stats[0]["time_midpoint"]
                elif channel_name.endswith("y"):
                    thresholds = result.eye_stats[1]["thresholds"]
                    time_midpoint = result.eye_stats[1]["time_midpoint"]
                else:
                    levels = np.linspace(np.min(dat), np.max(dat), 4)
                    thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels)
                    time_midpoint = 1.0

            # time_offset = time_midpoint - 0.5
            # # time_offset = 0

            # index_offset = np.argmin(np.abs((timestamps - time_offset) % 1.0))

            nos = len(timestamps) // result.sps

            # idx = np.arange(index_offset, len(timestamps), result.sps).astype(int)

            # if time_offset < 0:
            #     idx = np.insert(idx, 0, 0)

            idx = list(range(0, len(timestamps), result.sps))

            idx = idx[:nos]

            data_sampled = dat[idx]
            detected_symbols = self.detect_symbols(data_sampled, thresholds)

            symbols.append({"channel_name": channel_name, "symbols": detected_symbols})

        # `or` is ambiguous for ndarrays, so test for None explicitly
        symbols_x_gt = sym_x if sym_x is not None else symbols[0]["symbols"]
        symbols_y_gt = sym_y if sym_y is not None else symbols[1]["symbols"]

        symbols_x_fiber_out = symbols[2]["symbols"]
        symbols_y_fiber_out = symbols[3]["symbols"]

        symbols_x_regen = symbols[4]["symbols"]
        symbols_y_regen = symbols[5]["symbols"]

        sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out)
        sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out)
        sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen)
        sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen)

        result.sers = sers
        result.bers = bers

    @staticmethod
    def calculate_ser_ber(symbols_gt, symbols):
        # levels = 4
        # symbol difference -> bit error count
        # |rx - tx| = 0 -> 0
        # |rx - tx| = 1 -> 1
        # |rx - tx| = 2 -> 2
        # |rx - tx| = 3 -> 1
        # assuming gray coding -> 0: 00, 1: 01, 2: 11, 3: 10
        bec_map = {0: 0, 1: 1, 2: 2, 3: 1, np.nan: 2}

        ser = {}
        ber = {}
        ser["n_symbols"] = len(symbols_gt)
        ser["n_errors"] = np.sum(symbols != symbols_gt)
        ser["total"] = float(ser["n_errors"] / ser["n_symbols"])

        bec = np.vectorize(bec_map.get)(np.abs(symbols - symbols_gt))
        bit_errors = np.sum(bec)

        ber["n_bits"] = len(symbols_gt) * 2
        ber["n_errors"] = bit_errors
        ber["total"] = float(ber["n_errors"] / ber["n_bits"])

        return ser, ber
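A worked example of the Gray-coding map: adjacent PAM4 levels differ in one bit, 0↔2 and 1↔3 in two, and 0↔3 (00 vs 10) again in only one.

```py
import numpy as np

gt = np.array([0, 1, 2, 3, 3])
rx = np.array([0, 2, 2, 0, 3])
# |rx - gt| = [0, 1, 0, 3, 0]  ->  bit errors [0, 1, 0, 1, 0]
ser, ber = evaluator.calculate_ser_ber(gt, rx)
print(ser["total"], ber["total"])  # 0.4 0.2  (2/5 symbols, 2/10 bits)
```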

    @staticmethod
    def print_dict(d: dict, indent=2, logarithmic=False, level=0):
        for key, value in d.items():
            if isinstance(value, dict):
                print(f"{' ' * indent * level}{key}:")
                evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1)
            else:
                if isinstance(value, float):
                    if logarithmic:
                        if value == 0:
                            value = -np.inf
                        else:
                            value = np.log10(value)
                    print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="")
                else:
                    print(f"{' ' * indent * level}{key}: {value}\t", end="")
        print()

    @staticmethod
    def detect_symbols(samples, thresholds=None):
        thresholds = (1 / 6, 3 / 6, 5 / 6) if thresholds is None else thresholds
        thresholds = (-np.inf, *thresholds, np.inf)
        symbols = np.digitize(samples, thresholds) - 1
        return symbols
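Padding the thresholds with ±inf lets np.digitize map each sampled power directly to a symbol index:

```py
import numpy as np

samples = np.array([0.05, 0.30, 0.60, 0.95])
print(evaluator.detect_symbols(samples))  # [0 1 2 3] with the default 1/6, 3/6, 5/6 thresholds
```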

    @staticmethod
    def prepare_data(data, sps=None):
        data = data.transpose(1, 0)
        timestamps = data[-1].real
        data = data[:-1]
        if sps is not None:
            timestamps /= sps

        # data = np.stack(
        #     (
        #         *data[0:2],  # fiber_in_x, fiber_in_y
        #         # *data_[2:4],  # fiber_out_x, fiber_out_y
        #         *data[4:6],  # fiber_out_noisy_x, fiber_out_noisy_y
        #         *data[6:8],  # regen_out_x, regen_out_y
        #     ),
        #     axis=0,
        # )

        data_eye = []
        for channel_values in data:
            channel_values = np.square(np.abs(channel_values))
            data_eye.append(np.stack((timestamps, channel_values), axis=0))

        data_eye = np.stack(data_eye, axis=0)

        return data_eye


def generate_data(parameters, runner=None):
    runner = runner or model_runner()
    for param in parameters:
        runner.update_parameter_range(*param)
    runner.generate_iterations()
    print(f"{runner.iter.length} parameter combinations")
    runner.generate_datasets()

    return runner


if __name__ == "__main__":
    model_path = ".models/best_20250110_191149.tar"  # D 17, OSNR 100, delta_beta 0.14, baud 10e9

    parameters = (
        # name, config keys, (min, max), n_steps, log
        # ("D", ("fiber", "d"), (28, 30), 3, False),
        # ("S", ("fiber", "s"), (0, 0.058), 2, False),
        ("OSNR", ("signal", "osnr"), (20, 40), 5, False),
        # ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False),
        # ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True),
    )

    datasets = None
    results = None

    # datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5"
    results = "tolerance_results/datasets/results_list_20250110_232639.h5"

    runner = model_runner()
    # generate_data(parameters, runner)

    if results is None:
        if datasets is None:
            generate_data(parameters, runner)
        else:
            runner.load_datasets_from_file(datasets)
            print(f"{len(runner.datasets)} datasets loaded")

        runner.run_model_evaluation(model_path)
    else:
        runner.load_results_from_file(results)

    # print(runner.parameters)
    # print(runner.results)

    ev = evaluator(runner.results)  # avoid shadowing the eval() builtin
    ev.evaluate_datasets(plot=True)
@@ -481,6 +481,15 @@ class Identity(nn.Module):

    def forward(self, x):
        return x

class phase_shift(nn.Module):
    def __init__(self, size):
        super(phase_shift, self).__init__()
        self.size = size
        self.phase = nn.Parameter(torch.rand(size))

    def forward(self, x):
        return x * torch.exp(1j * self.phase)


class PowRot(nn.Module):
@@ -531,19 +540,19 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):

class EOActivation(nn.Module):
    def __init__(self, size=None):
        # 10.1109/SiPhotonics60897.2024.10543376
        # 10.1109/JSTQE.2019.2930455
        super(EOActivation, self).__init__()
        if size is None:
            raise ValueError("Size must be specified")
        self.size = size
        self.alpha = nn.Parameter(torch.ones(size))
        self.V_bias = nn.Parameter(torch.ones(size))
        self.gain = nn.Parameter(torch.ones(size))
        self.alpha = nn.Parameter(torch.rand(size))
        self.V_bias = nn.Parameter(torch.rand(size))
        self.gain = nn.Parameter(torch.rand(size))
        # if bias:
        #     self.phase_bias = nn.Parameter(torch.zeros(size))
        # else:
        #     self.register_buffer("phase_bias", torch.zeros(size))
        self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
        # self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
        self.register_buffer("responsivity", torch.ones(size)*0.9)
        self.register_buffer("V_pi", torch.ones(size)*3)

@@ -551,17 +560,17 @@ class EOActivation(nn.Module):

    def reset_weights(self):
        if "alpha" in self._parameters:
            self.alpha.data = torch.ones(self.size)*0.5
            self.alpha.data = torch.rand(self.size)
        if "V_pi" in self._parameters:
            self.V_pi.data = torch.ones(self.size)*3
            self.V_pi.data = torch.rand(self.size)*3
        if "V_bias" in self._parameters:
            self.V_bias.data = torch.zeros(self.size)
            self.V_bias.data = torch.randn(self.size)
        if "gain" in self._parameters:
            self.gain.data = torch.ones(self.size)
            self.gain.data = torch.rand(self.size)
        if "responsivity" in self._parameters:
            self.responsivity.data = torch.ones(self.size)*0.9
        if "bias" in self._parameters:
            self.phase_bias.data = torch.zeros(self.size)
        # if "bias" in self._parameters:
        #     self.phase_bias.data = torch.zeros(self.size)

    def forward(self, x: torch.Tensor):
        phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
@@ -570,12 +579,11 @@ class EOActivation(nn.Module):
        return (
            1j
            * torch.sqrt(1 - self.alpha)
            * torch.exp(-0.5j * (intermediate + self.phase_bias))
            * torch.exp(-0.5j * intermediate)
            * torch.cos(0.5 * intermediate)
            * x
        )
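The `intermediate` phase term falls outside this hunk. As a rough standalone sketch of the transfer function's shape — assuming, following the two papers cited in `__init__` rather than this commit's exact code, that the phase is the bias phase plus a photocurrent-driven term proportional to |x|²:

```py
import torch

def eo_activation_sketch(x, alpha=0.1, gain=1.0, responsivity=0.9, V_pi=3.0, V_bias=1.0):
    # Assumed reconstruction - the commit's `intermediate` may differ in detail.
    phi_b = torch.pi * V_bias / V_pi
    intermediate = phi_b + torch.pi * gain * responsivity * alpha * x.abs().square() / V_pi
    return 1j * torch.sqrt(torch.tensor(1.0 - alpha)) * torch.exp(-0.5j * intermediate) * torch.cos(0.5 * intermediate) * x

z = torch.tensor([0.1 + 0j, 1.0 + 0j, 2.0 + 0j])
print(eo_activation_sketch(z).abs())  # modulus saturates as |x|^2 drives the interferometer
```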


class Pow(nn.Module):
    """
    implements the activation function
@@ -716,6 +724,7 @@ __all__ = [
    MZISingle,
    EOActivation,
    photodiode,
    phase_shift,
    # SaturableAbsorberLambertW,
    # SaturableAbsorber,
    # SpreadLayer,

@@ -1,4 +1,5 @@
from pathlib import Path
import h5py
import torch
from torch.utils.data import Dataset

@@ -24,8 +25,22 @@ import multiprocessing as mp
#     def __len__(self):
#         return len(self.indices)

def load_from_file(datapath):
    if str(datapath).endswith('.h5'):
        symbols = None
        with h5py.File(datapath, "r") as infile:
            data = infile["data"][:]
            try:
                symbols = infile["symbols"][:]
            except KeyError:
                pass
    else:
        symbols = None
        data = np.load(datapath)
    return data, symbols

def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):

def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
    filepath = Path(config_path)
    filepath = filepath.parent.glob(filepath.name)
    config = configparser.ConfigParser()
@@ -40,15 +55,21 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals

    if symbols is None:
        symbols = int(config["glova"]["nos"]) - skipfirst

    data, orig_symbols = load_from_file(datapath)

    data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
    data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
    if orig_symbols is not None:  # .npy inputs carry no symbol record
        orig_symbols = orig_symbols[skipfirst : symbols + skipfirst]
    timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))

    if normalize:
        # normalize the square, as the power is (proportional to) the square of the amplitude;
        # each channel's peak power ends up at `normalize`
        a, b, c, d = np.square(data.T)
        a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
        data = np.sqrt(np.array([a, b, c, d]).T)
        data *= np.sqrt(normalize)

    # if normalize:
    #     # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
    #     a, b, c, d = data.T
    #     a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
    #     a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
    #     data = np.array([a, b, c, d]).T

    if real:
        data = np.abs(data)
@@ -59,7 +80,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals

    data = torch.tensor(data, device=device, dtype=dtype)

    return data, config
    return data, config, orig_symbols
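With `normalize` now numeric, each channel's peak power is scaled to that value rather than to 1 (the dataset class below passes normalize=1000). A quick check of the invariant on one real-valued stand-in channel:

```py
import numpy as np

rng = np.random.default_rng(0)
a = rng.random(128)             # stand-in for one polarisation's samples
p = np.square(a)                # power ~ amplitude squared
p = p / np.max(np.abs(p))       # peak power -> 1
a = np.sqrt(p) * np.sqrt(1000)  # normalize = 1000
print(np.max(np.square(a)))     # 1000.0
```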


def roll_along(arr, shifts, dim):
@@ -114,7 +135,8 @@ class FiberRegenerationDataset(Dataset):
        dtype: torch.dtype = None,
        real: bool = False,
        device=None,
        osnr: float = None,
        # osnr: float | None = None,
        polarisations=None,
        randomise_polarisations: bool = False,
        repeat_randoms: int = 1,
        **kwargs,
@@ -151,65 +173,50 @@ class FiberRegenerationDataset(Dataset):

        self.randomise_polarisations = randomise_polarisations

        faux = kwargs.pop("faux", False)

        if faux:
            data_raw = np.array(
                [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
                dtype=np.complex128,
        data_raw = None
        self.config = None
        files = []
        self.orig_symbols = None
        for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
            data, config, orig_syms = load_data(
                file_path,
                skipfirst=drop_first,
                symbols=kwargs.get("num_symbols", None),
                real=real,
                normalize=1000,
                device=device,
                dtype=dtype,
            )
            data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
            timestamps = torch.arange(12800)

            data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)

            self.config = {
                "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                "glova": {"sps": 128},
            }
        else:
            data_raw = None
            self.config = None
            files = []
            for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
                data, config = load_data(
                    file_path,
                    skipfirst=drop_first,
                    symbols=kwargs.get("num_symbols", None),
                    real=real,
                    normalize=True,
                    device=device,
                    dtype=dtype,
                )
            if data_raw is None:
                data_raw = data
            if orig_syms is not None:
                if self.orig_symbols is None:
                    self.orig_symbols = orig_syms
            else:
                data_raw = torch.cat([data_raw, data], dim=0)
            if self.config is None:
                self.config = config
            else:
                assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
                files.append(config["data"]["file"].strip('"'))
                self.config["data"]["file"] = str(files)
                self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)

                if data_raw is None:
                    data_raw = data
                else:
                    data_raw = torch.cat([data_raw, data], dim=0)
                if self.config is None:
                    self.config = config
                else:
                    assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
                    files.append(config["data"]["file"].strip('"'))
                    self.config["data"]["file"] = str(files)

        # if polarisations is not None:
        #     self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1)
        #     for i, angle in enumerate(torch.tensor(np.array(polarisations))):
        #         data_raw_copy = data_raw.clone()
        #         if angle == 0:
        #             continue
        #         sine = torch.sin(angle)
        #         cosine = torch.cos(angle)
        #         data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
        #         data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
        #         if i == 0:
        #             data_raw = data_raw_copy
        #         else:
        #             data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
        # if polarisations is not None:
        #     data_raw_clone = data_raw.clone()
        #     # rotate the polarisation by 180 degrees
        #     data_raw_clone[2, :] *= -1
        #     data_raw_clone[3, :] *= -1
        #     data_raw = torch.cat([data_raw, data_raw_clone], dim=0)

        self.polarisations = bool(polarisations)

        self.device = data_raw.device

        self.samples_per_symbol = int(self.config["glova"]["sps"])
        # self.num_symbols = int(self.config["glova"]["nos"])
        self.samples_per_slice = int(symbols * self.samples_per_symbol)
        self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol

@@ -290,6 +297,34 @@ class FiberRegenerationDataset(Dataset):
        fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
        fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)

        # fiber_out: [E_out_x, E_out_y, timestamps]

        # add the ASE noise of the amplification that compensates for splitting the signal
        gain_lin = output_dim * 2
        edfa_nf = float(self.config["signal"]["edfa_nf"])
        nf_lin = 10 ** (edfa_nf / 10)
        f0 = float(self.config["glova"]["f0"])

        # P_ASE = (G - 1) * NF * h * f0 * B_ref, with B_ref = 12.5 GHz
        noise_add = (gain_lin - 1) * nf_lin * 6.626e-34 * f0 * 12.5e9

        noise = torch.randn_like(fiber_out[:2, :])
        noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
        noise = noise * torch.sqrt(noise_add / noise_power)
        fiber_out[:2, :] += noise
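Plugging in round numbers shows the scale of that noise floor. The noise figure and carrier frequency here are made-up placeholders; the config supplies the real values:

```py
# Hypothetical numbers: 128-way effective gain, 5 dB noise figure, 193.4 THz carrier.
h = 6.626e-34
gain_lin, nf_lin, f0 = 128, 10 ** (5 / 10), 193.4e12
noise_add = (gain_lin - 1) * nf_lin * h * f0 * 12.5e9
print(f"{noise_add:.2e} W")  # ~6.4e-07 W of ASE in the 12.5 GHz reference band
```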

        # if osnr is None:
        #     noisy = fiber_out[:2, :]
        # else:
        #     noisy = self.add_noise(fiber_out[:2, :], osnr)

        # fiber_out = torch.cat([fiber_out, noisy], dim=0)

        # fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]


        if repeat_randoms > 1:
            fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
            fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
@@ -298,28 +333,34 @@ class FiberRegenerationDataset(Dataset):
            repeat_randoms = 1

        if self.randomise_polarisations:
            angles = torch.fmod(torch.randperm(data_raw.shape[-1] * repeat_randoms), 2) * torch.pi
            angles = torch.fmod(torch.randperm(data_raw.shape[-1] * repeat_randoms, device=fiber_out.device), 2) * torch.pi
            # start_angle = torch.rand(1) * 2 * torch.pi
            # angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1]) / 333, dim=0) * torch.pi * 2  # random walk
            # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
        else:
            angles = torch.zeros(data_raw.shape[-1])
            angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)

        sin = torch.sin(angles)
        cos = torch.cos(angles)
        rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
        data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
        # data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
        fiber_out = torch.cat((fiber_out, data_rot), dim=0)
        fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
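Each time step gets its own 2×2 rotation, applied to the (x, y) Jones vector via a batched matmul; with the fmod/randperm construction above, the angle is either 0 or π, i.e. a random 180° flip. A minimal standalone version of that step:

```py
import torch

# Per-sample rotation of a dual-polarisation field, as in the dataset above.
n = 4
field = torch.randn(2, n, dtype=torch.complex128)      # rows: E_x, E_y
angles = torch.fmod(torch.randperm(n), 2) * torch.pi   # 0 or pi per sample
sin, cos = torch.sin(angles), torch.cos(angles)
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
rotated = torch.bmm(field.T.unsqueeze(1), rot.to(field.dtype)).squeeze(1).T
print(torch.allclose(rotated, field * torch.where(angles == 0, 1.0, -1.0)))  # True
```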

        if osnr is not None:
            popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1)
            noise = torch.randn_like(fiber_out[:2, :, :])
            pn = torch.mean(noise.abs().flatten(), dim=-1)
            noise = noise * (popt / pn) * 10 ** (-osnr / 20)
            fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise)
        # fiber_in:
        #     0 E_in_x,
        #     1 E_in_y,
        #     2 timestamps

        # fiber_out:
        #     0 E_out_x,
        #     1 E_out_y,
        #     2 timestamps,
        #     3 E_out_x_rot,
        #     4 E_out_y_rot,
        #     5 angle


        # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
@@ -349,6 +390,22 @@ class FiberRegenerationDataset(Dataset):

    def __len__(self):
        return self.fiber_in.shape[0]

    def add_noise(self, data, osnr):
        osnr_lin = 10 ** (osnr / 10)
        popt = torch.mean(data.abs().square().squeeze(), dim=-1)
        noise = torch.randn_like(data)
        pn = torch.mean(noise.abs().square().squeeze(), dim=-1)

        mult = torch.sqrt(popt / (pn * osnr_lin))
        mult = mult * torch.eye(popt.shape[0], device=mult.device)
        mult = mult.to(dtype=noise.dtype)

        noise = mult @ noise
        pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
        noisy = data + noise
        return noisy
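The diagonal `mult` scales each row's noise so that the per-channel ratio P_signal/P_noise lands exactly on the requested linear OSNR. A quick check of that ratio:

```py
import torch

torch.manual_seed(0)
signal = torch.randn(2, 4096, dtype=torch.complex128)
osnr_lin = 10 ** (20 / 10)  # 20 dB

popt = torch.mean(signal.abs().square(), dim=-1)
noise = torch.randn_like(signal)
pn = torch.mean(noise.abs().square(), dim=-1)
noise = torch.diag(torch.sqrt(popt / (pn * osnr_lin))).to(noise.dtype) @ noise
print(popt / torch.mean(noise.abs().square(), dim=-1))  # ~[100., 100.]
```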


    def __getitem__(self, idx):
        if isinstance(idx, slice):
@@ -357,14 +414,19 @@ class FiberRegenerationDataset(Dataset):
        # fiber in: [E_in_x, E_in_y, timestamps]
        # fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]

        # if self.polarisations:
        output_dim = self.output_dim // 2
        self.output_dim = output_dim * 2

        fiber_in = self.fiber_in[idx].squeeze()
        fiber_out = self.fiber_out[idx].squeeze()

        fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim]
        fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim]
        fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
        fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]

        fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
        fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)

        fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1)
        fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1)

        # data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]

@@ -372,11 +434,36 @@ class FiberRegenerationDataset(Dataset):

        # angle = self.angles[idx]

        center_angle = fiber_out[5, self.output_dim // 2, 0]
        # fiber_in:
        #     0 E_in_x,
        #     1 E_in_y,
        #     2 timestamps

        # fiber_out:
        #     0 E_out_x,
        #     1 E_out_y,
        #     2 timestamps,
        #     3 E_out_x_rot,
        #     4 E_out_y_rot,
        #     5 angle

        center_angle = fiber_out[5, output_dim // 2, 0]  # row 5 holds the angle (see layout above)
        angles = fiber_out[5, :, 0]
        plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone()
        plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone()
        data = fiber_out[3:5, :, 0]
        plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
        plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
        data = fiber_out[0:2, :, 0]
        # fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone()


        # if self.polarisations:
        #     rot = int(np.random.randint(2)*2-1)
        #     pol_flipped_data[0:1, :] = rot*data[0, :]
        #     pol_flipped_data[1, :] = rot*data[1, :]
        #     plot_data_rot[0] = rot*plot_data_rot[0]
        #     plot_data_rot[1] = rot*plot_data_rot[1]
        #     center_angle = center_angle + (rot - 1) * torch.pi/2  # rot: -1 or 1 -> -2 or 0 -> -pi or 0
        #     angles = angles + (rot - 1) * torch.pi/2


        # if self.randomise_polarisations:
        #     data = data.mT
@@ -389,16 +476,27 @@ class FiberRegenerationDataset(Dataset):
        #     angle = torch.zeros_like(angle)

        # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
        angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim)
        angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim)
        # angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim)
        # angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim)
        # sop = self.polarimeter(plot_data)
        # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
        # angle = data_slice[1, 3, self.output_dim // 2, 0].real
        target = fiber_in[:2, self.output_dim // 2, 0]
        plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone()
        target_timestamp = fiber_in[2, self.output_dim // 2, 0].real
        target = fiber_in[:2, output_dim // 2, 0]
        plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
        target_timestamp = fiber_in[2, output_dim // 2, 0].real

        if self.polarisations:
            rot = int(np.random.randint(2) * 2 - 1)  # maps {0, 1} -> {-1, +1}
            data = rot * data
            target = rot * target
            plot_data_rot = rot * plot_data_rot
            center_angle = center_angle + (rot - 1) * torch.pi / 2  # rot: -1 or 1 -> -2 or 0 -> -pi or 0
            angles = angles + (rot - 1) * torch.pi / 2

        pol_flipped_data = -data
        pol_flipped_target = -target

        # data_timestamps = data[-1,:].real
        # data = data[:-1, :]
        # target_timestamp = target[-1].real
@@ -407,13 +505,15 @@ class FiberRegenerationDataset(Dataset):

        # transpose to interleave the x and y data in the output tensor
        data = data.transpose(0, 1).flatten().squeeze()
        angle_data = angle_data.transpose(0, 1).flatten().squeeze()
        angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
        pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
        # angle_data = angle_data.transpose(0, 1).flatten().squeeze()
        # angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
        center_angle = center_angle.flatten().squeeze()
        angles = angles.flatten().squeeze()
        # data_timestamps = data_timestamps.flatten().squeeze()
        # target = target.transpose(0,1).flatten().squeeze()
        target = target.flatten().squeeze()
        pol_flipped_target = pol_flipped_target.flatten().squeeze()
        target_timestamp = target_timestamp.flatten().squeeze()
        plot_target = plot_target.flatten().squeeze()
        plot_data = plot_data.flatten().squeeze()

@@ -421,17 +521,22 @@ class FiberRegenerationDataset(Dataset):

        return {
            "x": data,
            "x_flipped": pol_flipped_data,
            "x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
            "y": target,
            "center_angle": center_angle,
            "angles": angles,
            "y_flipped": pol_flipped_target,
            "y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
            # "center_angle": center_angle,
            # "angles": angles,
            "mean_angle": angles.mean(),
            # "sop": sop,
            "angle_data": angle_data,
            "angle_data2": angle_data2,
            # "angle_data": angle_data,
            # "angle_data2": angle_data2,
            "timestamp": target_timestamp,
            "plot_target": plot_target,
            "plot_data": plot_data,
            "plot_data_rot": plot_data_rot,
            # "plot_clean": fiber_out_plot_clean,
        }
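These keys line up with what run_model() in the tolerance script reads back out; a minimal consumer sketch (the constructor arguments here are placeholders, not a real dataset):

```py
import torch

# Placeholder construction - file_path/symbols/output_dim must match a real dataset.
dataset = FiberRegenerationDataset(file_path="data/example.ini", symbols=1, output_dim=128)
loader = torch.utils.data.DataLoader(dataset, batch_size=16)
for batch in loader:
    x, y, angle = batch["x_stacked"], batch["y_stacked"], batch["mean_angle"]
    # regen = model(x, -angle)  # the model is expected to undo the rotation itself
    break
```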

    def complex_max(self, data, dim=-1):

@@ -82,7 +82,6 @@ class eye_diagram:
        self.vertical_bins = vertical_bins
        self.multi_threaded = multithreaded
        self.eye_built = False
        self.analyse()

    def generate_eye_data(self):
        self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
@@ -126,6 +125,7 @@ class eye_diagram:
        rows = int(np.ceil(self.channels / cols))
        fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
        fig.suptitle(title)
        fig.tight_layout()
        ax = np.atleast_1d(ax).transpose().flatten()
        for i in range(self.channels):
            ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
@@ -147,19 +147,21 @@ class eye_diagram:
            yspan = ymax - ymin
            ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
            if stats and self.eye_stats[i]["success"]:
                # add min_area above the plot
                ax[i].annotate(
                    f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
                    xy=(0.05, ymax + 0.05 * yspan),
                    # xycoords="axes fraction",
                    ha="left",
                    va="center",
                    bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                )
                # # add min_area above the plot
                # ax[i].annotate(
                #     f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
                #     xy=(0.05, ymax + 0.05 * yspan),
                #     # xycoords="axes fraction",
                #     ha="left",
                #     va="center",
                #     bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                # )

                if all_stats:
                    ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
                    ax[i].set_yticks(self.eye_stats[i]["levels"])
                    y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
                    # y_ticks = np.sort(y_ticks)
                    ax[i].set_yticks(y_ticks)
                    # add arrows for amplitudes
                    for j in range(len(self.eye_stats[i]["amplitudes"])):
                        ax[i].annotate(
@@ -193,35 +195,35 @@ class eye_diagram:
                        except (ValueError, IndexError):
                            pass
                    # add arrows for eye widths
                    for j in range(len(self.eye_stats[i]["widths"])):
                        try:
                            left = np.max(self.eye_stats[i]["time_clusters"][j][0])
                            right = np.min(self.eye_stats[i]["time_clusters"][j][1])
                            vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
                    # for j in range(len(self.eye_stats[i]["widths"])):
                    #     try:
                    #         left = np.max(self.eye_stats[i]["time_clusters"][j][0])
                    #         right = np.min(self.eye_stats[i]["time_clusters"][j][1])
                    #         vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2

                            ax[i].annotate(
                                "",
                                xy=(left, vertical),
                                xytext=(right, vertical),
                                arrowprops=dict(arrowstyle="<->", facecolor="black"),
                            )
                            ax[i].annotate(
                                f"{self.eye_stats[i]['widths'][j]:.2e}",
                                xy=((left + right) / 2 - 0.15, vertical + 0.01),
                                bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                            )
                        except (ValueError, IndexError):
                            pass
                    #         ax[i].annotate(
                    #             "",
                    #             xy=(left, vertical),
                    #             xytext=(right, vertical),
                    #             arrowprops=dict(arrowstyle="<->", facecolor="black"),
                    #         )
                    #         ax[i].annotate(
                    #             f"{self.eye_stats[i]['widths'][j]:.2e}",
                    #             xy=((left + right) / 2 - 0.15, vertical + 0.01),
                    #             bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                    #         )
                    #     except (ValueError, IndexError):
                    #         pass

                    # add area
                    for j in range(len(self.eye_stats[i]["areas"])):
                        horizontal = self.eye_stats[i]["time_midpoint"]
                        vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
                        ax[i].annotate(
                            f"{self.eye_stats[i]['areas'][j]:.2e}",
                            xy=(horizontal + 0.035, vertical - 0.07),
                            bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                        )
                    # # add area
                    # for j in range(len(self.eye_stats[i]["areas"])):
                    #     horizontal = self.eye_stats[i]["time_midpoint"]
                    #     vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
                    #     ax[i].annotate(
                    #         f"{self.eye_stats[i]['areas'][j]:.2e}",
                    #         xy=(horizontal + 0.035, vertical - 0.07),
                    #         bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                    #     )

        fig.tight_layout()

@@ -229,6 +231,12 @@ class eye_diagram:
            plt.show()
        return fig

    @staticmethod
    def calculate_thresholds(levels):
        # pairwise moving average: midpoint between each pair of adjacent levels
        ret = np.cumsum(levels, dtype=float)
        ret[2:] = ret[2:] - ret[:-2]
        return ret[1:] / 2
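The cumsum trick is just a vectorised midpoint between neighbouring levels:

```py
import numpy as np

levels = np.array([0.0, 1.0, 2.0, 3.0])
print(eye_diagram.calculate_thresholds(levels))  # [0.5 1.5 2.5]
# equivalent to (levels[1:] + levels[:-1]) / 2
```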

    def analyse_single(self, data, index):
        warnings.filterwarnings("error")
        eye_stats = {}
@@ -238,12 +246,15 @@ class eye_diagram:

            time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)

            eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
            eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2
            eye_stats["time_midpoint"] = 1.0

            eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
                data, approx_levels, time_bounds
            )

            eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])

            eye_stats["amplitudes"] = np.diff(eye_stats["levels"])

            eye_stats["heights"] = eye_diagram.calculate_eye_heights(
@@ -260,22 +271,23 @@ class eye_diagram:
            # if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
            #     raise ValueError

            eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
            eye_stats["mean_area"] = np.mean(eye_stats["areas"])
            eye_stats["min_area"] = np.min(eye_stats["areas"])
            # eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
            # eye_stats["mean_area"] = np.mean(eye_stats["areas"])
            # eye_stats["min_area"] = np.min(eye_stats["areas"])

            eye_stats["success"] = True
        except (RuntimeWarning, UserWarning, ValueError):
            eye_stats["success"] = False
            eye_stats["time_midpoint"] = 0
            eye_stats["levels"] = np.zeros(self.n_levels)
            eye_stats["amplitude_clusters"] = []
            eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
            eye_stats["heights"] = np.zeros(self.n_levels - 1)
            eye_stats["widths"] = np.zeros(self.n_levels - 1)
            eye_stats["areas"] = np.zeros(self.n_levels - 1)
            eye_stats["mean_area"] = 0
            eye_stats["min_area"] = 0
            eye_stats["time_midpoint"] = None
            eye_stats["levels"] = None
            eye_stats["thresholds"] = None
            eye_stats["amplitude_clusters"] = None
            eye_stats["amplitudes"] = None
            eye_stats["heights"] = None
            eye_stats["widths"] = None
            # eye_stats["areas"] = np.zeros(self.n_levels - 1)
            # eye_stats["mean_area"] = 0
            # eye_stats["min_area"] = 0
        warnings.resetwarnings()
        return eye_stats

@@ -441,7 +453,8 @@ if __name__ == "__main__":

    data = generate_sample_data(length, noise=0.005)
    eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
    attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
    eye.analyse()
    attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")  # , "area", "mean_area", "min_area")
    for i, channel in enumerate(eye.eye_stats):
        print(f"Channel {i}")
        print_data = {attr: channel[attr] for attr in attrs}

@@ -1,6 +1,9 @@
import matplotlib.pyplot as plt
import numpy as np
from .datasets import load_data
if __name__ == "__main__":
    from datasets import load_data
else:
    from .datasets import load_data

def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
    """Plot an eye diagram for the data given by filepath.
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
        raise ValueError("Either path or data and sps must be given.")
    if path is not None:
        data, config, _ = load_data(path, skipfirst, symbols)  # load_data now returns (data, config, orig_symbols)
        data = data.detach().cpu().numpy()[:, :4]
        sps = int(config["glova"]["sps"])
    if sps is None:
        raise ValueError("sps not set.")
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
        plt.show()

    return fig

if __name__ == "__main__":
    eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)