From f38d0ca3bb8c7d93c0680648d21ba8928bebd1b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joseph=20Hopfm=C3=BCller?= Date: Fri, 10 Jan 2025 23:40:54 +0100 Subject: [PATCH] model robustness testing --- .gitignore | 1 + notes/models.md | 37 + src/single-core-regen/hypertraining/models.py | 14 +- .../hypertraining/training.py | 69 +- src/single-core-regen/plot_model.py | 253 ++++++ src/single-core-regen/regen_no_hyper.py | 39 +- .../signal_gen}/add_pypho.py | 0 .../signal_gen}/generate_signal.py | 312 ++++---- src/single-core-regen/tolerance_testing.py | 723 ++++++++++++++++++ src/single-core-regen/util/complexNN.py | 35 +- src/single-core-regen/util/datasets.py | 283 ++++--- src/single-core-regen/util/eye_diagram.py | 117 +-- src/single-core-regen/util/plot.py | 9 +- 13 files changed, 1558 insertions(+), 334 deletions(-) create mode 100644 notes/models.md create mode 100644 src/single-core-regen/plot_model.py rename src/{single-core-data-gen => single-core-regen/signal_gen}/add_pypho.py (100%) rename src/{single-core-data-gen => single-core-regen/signal_gen}/generate_signal.py (69%) create mode 100644 src/single-core-regen/tolerance_testing.py diff --git a/.gitignore b/.gitignore index 75bcb68..1023046 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +tolerance_results/datasets/* diff --git a/notes/models.md b/notes/models.md new file mode 100644 index 0000000..b8073fb --- /dev/null +++ b/notes/models.md @@ -0,0 +1,37 @@ +# models + +## no polarisation flipping + +```py +config_path="data/20241229-163*-128-16384-50000-*.ini" +model=".models/best_20241230_011907.tar" +``` + +```py +config_path="data/20241229-163*-128-16384-80000-*.ini" +model=".models/best_20241230_103752.tar" +``` + +```py +config_path="data/20241229-163*-128-16384-100000-*.ini" +model=".models/best_20241230_164534.tar" +``` + +## with polarisation flipping + +polarisation flipping: signal is randomly rotated by 180°. polarization rotation can be detected by adding a tone on one of the polarisations, but only to mod 180° with a direct detection setup. the randomly flipped signal should allow the network to hopefully learn to compensate for dispersion, pmd independently from the polarization rot. the training data includes the flipped signal as well, but no indication if the polarisation is flipped. + +```py +config_path="data/20241229-163*-128-16384-50000-*.ini" +model=".models/best_20241231_000328.tar" +``` + +```py +config_path="data/20241229-163*-128-16384-80000-*.ini" +model=".models/best_20241231_163614.tar" +``` + +```py +config_path="data/20241229-163*-128-16384-100000-*.ini" +model=".models/best_20241231_170532.tar" +``` diff --git a/src/single-core-regen/hypertraining/models.py b/src/single-core-regen/hypertraining/models.py index 22240b1..756738c 100644 --- a/src/single-core-regen/hypertraining/models.py +++ b/src/single-core-regen/hypertraining/models.py @@ -124,7 +124,7 @@ class regenerator(Module): parametrizations: list[dict] = None, dtype=torch.float64, dropout_prob=0.01, - scale_layers=False, + prescale=1, rotate=False, ): super(regenerator, self).__init__() @@ -134,15 +134,14 @@ class regenerator(Module): act_func_kwargs = act_func_kwargs or {} self.rotation = rotate + self.prescale = prescale - self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers) + self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob) - def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers): + def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob): for i in range(0, self._n_hidden_layers): self.add_module(f"layer_{i}", Sequential()) - if scale_layers: - self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i])) module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs) self.get_submodule(f"layer_{i}").add_module("ONN", module) @@ -156,8 +155,8 @@ class regenerator(Module): self.add_module(f"layer_{self._n_hidden_layers}", Sequential()) - if scale_layers: - self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2])) + # if scale_layers: + # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2])) module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs) self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module) @@ -200,6 +199,7 @@ class regenerator(Module): return powers def forward(self, x, angle=None, pre_rot=False, trace_powers=False): + x = x * self.prescale powers = self._trace_powers(trace_powers, x) # x = self.layer_0(x) # powers = self._trace_powers(trace_powers, x, powers) diff --git a/src/single-core-regen/hypertraining/training.py b/src/single-core-regen/hypertraining/training.py index 22fb705..ff54316 100644 --- a/src/single-core-regen/hypertraining/training.py +++ b/src/single-core-regen/hypertraining/training.py @@ -683,7 +683,7 @@ class RegenerationTrainer: def define_model(self, model_kwargs=None): if self.resume: - model_kwargs = self.checkpoint_dict["model_kwargs"] + model_kwargs = None else: model_kwargs = model_kwargs @@ -692,6 +692,14 @@ class RegenerationTrainer: input_dim = 2 * self.data_settings.output_size + # if self.data_settings.polarisations: + # input_dim *= 2 + + output_dim = self.model_settings.output_dim + + # if self.data_settings.polarisations: + output_dim *= 2 + dtype = getattr(torch, self.data_settings.dtype) afunc = getattr(util.complexNN, self.model_settings.model_activation_func) @@ -703,7 +711,7 @@ class RegenerationTrainer: hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)] self.model_kwargs = { - "dims": (input_dim, *hidden_dims, self.model_settings.output_dim), + "dims": (input_dim, *hidden_dims, output_dim), "layer_function": layer_func, "layer_func_kwargs": self.model_settings.model_layer_kwargs, "act_function": afunc, @@ -711,7 +719,7 @@ class RegenerationTrainer: "parametrizations": layer_parametrizations, "dtype": dtype, "dropout_prob": self.model_settings.dropout_prob, - "scale_layers": self.model_settings.scale, + "prescale": self.model_settings.scale, } else: self.model_kwargs = model_kwargs @@ -745,11 +753,12 @@ class RegenerationTrainer: num_symbols = None config_path = self.data_settings.config_path randomise_polarisations = self.data_settings.randomise_polarisations + polarisations = self.data_settings.polarisations osnr = self.data_settings.osnr if override is not None: num_symbols = override.get("num_symbols", None) config_path = override.get("config_path", config_path) - # polarisations = override.get("polarisations", polarisations) + polarisations = override.get("polarisations", polarisations) randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations) # get dataset dataset = FiberRegenerationDataset( @@ -763,6 +772,7 @@ class RegenerationTrainer: real=not dtype.is_complex, num_symbols=num_symbols, randomise_polarisations=randomise_polarisations, + polarisations=polarisations, osnr = osnr, ) @@ -832,17 +842,19 @@ class RegenerationTrainer: running_loss = 0.0 self.model.train() loader_len = len(train_loader) + x_key = "x_stacked"# if self.data_settings.polarisations else "x" + y_key = "y_stacked"# if self.data_settings.polarisations else "y" for batch_idx, batch in enumerate(train_loader): - x = batch["x"] - y = batch["y"] - angles = batch["mean_angle"] + x = batch[x_key] + y = batch[y_key] + angle = batch["mean_angle"] self.model.zero_grad(set_to_none=True) - x, y, angles = ( + x, y, angle = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), - angles.to(self.pytorch_settings.device), + angle.to(self.pytorch_settings.device), ) - y_pred = self.model(x, -angles) + y_pred = self.model(x, -angle) loss = util.complexNN.complex_mse_loss(y_pred, y, power=True) loss_value = loss.item() loss.backward() @@ -886,17 +898,19 @@ class RegenerationTrainer: self.model.eval() running_error = 0 + x_key = "x_stacked"# if self.data_settings.polarisations else "x" + y_key = "y_stacked"# if self.data_settings.polarisations else "y" with torch.no_grad(): for _, batch in enumerate(valid_loader): - x = batch["x"] - y = batch["y"] - angles = batch["mean_angle"] - x, y, angles = ( + x = batch[x_key] + y = batch[y_key] + angle = batch["mean_angle"] + x, y, angle = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), - angles.to(self.pytorch_settings.device), + angle.to(self.pytorch_settings.device), ) - y_pred = self.model(x, -angles) + y_pred = self.model(x, -angle) error = util.complexNN.complex_mse_loss(y_pred, y, power=True) error_value = error.item() running_error += error_value @@ -953,15 +967,17 @@ class RegenerationTrainer: regen = [] timestamps = [] angles = [] + x_key = "x_stacked"# if self.data_settings.polarisations else "x" + y_key = "y_stacked"# if self.data_settings.polarisations else "y" with torch.no_grad(): model = model.to(self.pytorch_settings.device) for batch in loader: - x = batch["x"] - y = batch["y"] + x = batch[x_key] + y = batch[y_key] plot_target = batch["plot_target"] angle = batch["mean_angle"] - center_angle = batch["center_angle"] + # center_angle = batch["center_angle"] timestamp = batch["timestamp"] plot_data = batch["plot_data"] plot_data_rot = batch["plot_data_rot"] @@ -971,14 +987,16 @@ class RegenerationTrainer: angle.to(self.pytorch_settings.device), ) if trace_powers: - y_pred, powers = model(x, angle, True).cpu() + y_pred, powers = model(x, -angle, True).cpu() else: - y_pred = model(x, angle).cpu() + y_pred = model(x, -angle).cpu() # x = x.cpu() # y = y.cpu() + # if self.data_settings.polarisations: + y_pred = y_pred[:, :2] y_pred = y_pred.view(y_pred.shape[0], -1, 2) y_pred = y_pred[:, y_pred.shape[1]//2, :] - y = y.view(y.shape[0], -1, 2) + # y = y.view(y.shape[0], -1, 2) # plot_data = plot_data.view(plot_data.shape[0], -1, 2) # c = torch.cos(-angle).cpu() # s = torch.sin(-angle).cpu() @@ -996,7 +1014,7 @@ class RegenerationTrainer: fiber_in.append(plot_target.squeeze()) regen.append(y_pred.squeeze()) timestamps.append(timestamp.squeeze()) - angles.append(center_angle.squeeze()) + angles.append(angle.squeeze()) fiber_out = torch.vstack(fiber_out).cpu() fiber_out_rot = torch.vstack(fiber_out_rot).cpu() @@ -1352,7 +1370,8 @@ class RegenerationTrainer: "num_symbols": self.pytorch_settings.batchsize, "config_path": config_path, "shuffle": False, - "polarisations": (np.random.rand(1) * np.pi * 2,), + # "polarisations": (np.random.rand(1) * np.pi * 2,), + "polarisations": self.data_settings.polarisations, "randomise_polarisation": self.data_settings.randomise_polarisations, } ) @@ -1366,7 +1385,7 @@ class RegenerationTrainer: fiber_out_rot = fiber_out_rot.view(-1, 2) angles = angles.view(-1, 1) angles = angles.real - angles = torch.fmod(angles, 2 * torch.pi) + angles = torch.fmod(angles, 2*torch.pi) angles = torch.div(angles, 2*torch.pi) angles = torch.repeat_interleave(angles, 2, dim=1) diff --git a/src/single-core-regen/plot_model.py b/src/single-core-regen/plot_model.py new file mode 100644 index 0000000..27eab39 --- /dev/null +++ b/src/single-core-regen/plot_model.py @@ -0,0 +1,253 @@ +import os +from matplotlib import pyplot as plt +import numpy as np +import torch +import util +from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings +from hypertraining import models + +# def move_to_location_in_size(array, location, size): +# array_x, array_y = array.shape +# location_x, location_y = location +# size_x, size_y = size + +# left = location_x +# right = size_x - array_x - location_x + +# top = location_y +# bottom = size_y - array_y - location_y + +# return np.pad( +# array, +# ( +# (left, right), +# (top, bottom), +# ), +# constant_values=(-np.inf, -np.inf), +# ) + + +def pad_to_size(array, size): + if not hasattr(size, "__len__"): + size = (size, size) + + left = ( + (size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0 + ) + right = ( + (size[0] - array.shape[0]) // 2 if size[0] is not None else 0 + ) + top = ( + (size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0 + ) + bottom = ( + (size[1] - array.shape[1]) // 2 if size[1] is not None else 0 + ) + + array: np.ndarray = array + if array.ndim == 2: + return np.pad( + array, + ( + (left, right), + (top, bottom), + ), + constant_values=(np.nan, np.nan), + ) + elif array.ndim == 3: + return np.pad( + array, + ( + (left, right), + (top, bottom), + (0,0) + ), + constant_values=(np.nan, np.nan), + ) + +def model_plot(model_path): + torch.serialization.add_safe_globals([ + *util.complexNN.__all__, + GlobalSettings, + DataSettings, + ModelSettings, + OptimizerSettings, + PytorchSettings, + models.regenerator, + torch.nn.utils.parametrizations.orthogonal, + ]) + checkpoint_dict = torch.load(model_path, weights_only=True) + + dims = checkpoint_dict["model_kwargs"].pop("dims") + + model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"]) + model.load_state_dict(checkpoint_dict["model_state_dict"]) + + model_params = [] + plots = [] + max_size = np.max(dims) + # max_act_size = np.max(dims[1:]) + + angles = [None, None] + weights = [None, None] + + for num, (layer_name, layer) in enumerate(model.named_children()): + # each layer contains an "ONN" layer and an "activation" layer + # activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees + onn_weights = layer.ONN.weight.T + onn_weights = onn_weights.detach().cpu().numpy() + onn_values = np.abs(onn_weights).real + onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real + + + act = layer.activation + + act_values = np.ones((act.size, 1)) + + act_values = np.nan * act_values + + act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy() + ... + # act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8) + # act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8) + # xs = (0.01, 0.1, 1) + + # act_values = np.zeros((act.size, len(xs)*2)) + # act_angles = np.zeros((act.size, len(xs)*2)) + + # act_values[:,:] = np.nan + # act_angles[:,:] = np.nan + + # for xi, x in enumerate(xs): + # phi_intermediate = act_phi_gain * x**2 + act_phi_bias + + # act_resulting_gain = ( + # 1j + # * torch.sqrt(1-act.alpha) + # * torch.exp(-0.5j * phi_intermediate) + # * torch.cos(0.5 * phi_intermediate) + # * x + # ) + + # act_resulting_gain = act_resulting_gain.detach().cpu().numpy() + # act_values[:, xi*2] = np.abs(act_resulting_gain).real + # act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real + + + + # if angles[0] is None or angles[0] > np.min(onn_angles.flatten()): + # angles[0] = np.min(onn_angles.flatten()) + # if angles[1] is None or angles[1] < np.max(onn_angles.flatten()): + # angles[1] = np.max(onn_angles.flatten()) + # if weights[0] is None or weights[0] > np.min(onn_weights.flatten()): + # weights[0] = np.min(onn_weights.flatten()) + # if weights[1] is None or weights[1] < np.max(onn_weights.flatten()): + # weights[1] = np.max(onn_weights.flatten()) + + model_params.append({layer_name: onn_weights}) + plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)}) + + # fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5)) + + for plot in plots: + layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem() + # for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0]) + # for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0]) + + onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values)) + onn_values = onn_values - np.min(onn_values) + onn_values = onn_values / np.max(onn_values) + + act_values = np.ma.array(act_values, mask=np.isnan(act_values)) + act_values = act_values - np.min(act_values) + act_values = act_values / np.max(act_values) + + + onn_values = onn_values + onn_values = pad_to_size(onn_values, (max_size, None)) + + act_values = act_values + act_values = pad_to_size(act_values, (max_size, 3)) + + onn_angles = onn_angles / np.pi + onn_angles = pad_to_size(onn_angles, (max_size, None)) + + act_angles = act_angles / np.pi + act_angles = pad_to_size(act_angles, (max_size, 3)) + + + + # onn_angles = onn_angles - np.min(onn_angles) + # onn_angles = onn_angles / np.max(onn_angles) + + # act_angles = act_angles - np.min(act_angles) + # act_angles = act_angles / np.max(act_angles) + + if num == 0: + value_img = np.concatenate((onn_values, act_values), axis=1) + angle_img = np.concatenate((onn_angles, act_angles), axis=1) + else: + value_img = np.concatenate((value_img, onn_values, act_values), axis=1) + angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1) + + + + + # -np.inf to np.nan + # value_img[value_img == -np.inf] = np.nan + + # angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size) + # angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size) + + + + + from cmcrameri import cm + from matplotlib import colors as mcolors + alpha_map = mcolors.LinearSegmentedColormap( + 'alphamap', + { + 'red': [(0, 0, 0), (1, 0, 0)], + 'green': [(0, 0, 0), (1, 0, 0)], + 'blue': [(0, 0, 0), (1, 0, 0)], + 'alpha': [ + (0, 1, 1), + # (0.2, 0.2, 0.1), + (1, 0, 0) + ] + } + ) + alpha_map.set_bad(color="#AAAAAA") + + + fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5)) + fig.tight_layout() + # masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img) + masked_value_img = value_img + cmap = cm.batlowW + cmap.set_bad(color="#AAAAAA") + im_val = axs[0].imshow(masked_value_img, cmap=cmap) + + masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img) + cmap = cm.romaO + cmap.set_bad(color="#AAAAAA") + im_ang = axs[1].imshow(masked_angle_img, cmap=cmap) + im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap) + im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map) + + axs[0].axis("off") + axs[1].axis("off") + axs[2].axis("off") + + axs[0].set_title("Values") + axs[1].set_title("Angles") + axs[2].set_title("Values and Angles") + + + ... + plt.show() + # model = models.regenerator(*dims, **model_kwargs) + + +if __name__ == "__main__": + model_plot(".models/best_20250105_145719.tar") diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py index 84ca548..d68feeb 100644 --- a/src/single-core-regen/regen_no_hyper.py +++ b/src/single-core-regen/regen_no_hyper.py @@ -1,4 +1,4 @@ -from datetime import datetime +# from datetime import datetime from pathlib import Path import matplotlib import numpy as np @@ -13,7 +13,7 @@ from hypertraining.settings import ( OptimizerSettings, ) -from hypertraining.training import RegenerationTrainer, PolarizationTrainer +from hypertraining.training import RegenerationTrainer#, PolarizationTrainer # import torch import json @@ -27,7 +27,7 @@ global_settings = GlobalSettings( data_settings = DataSettings( # config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini", - config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini", + config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini", # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], dtype="complex64", # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber @@ -37,12 +37,13 @@ data_settings = DataSettings( shuffle=True, drop_first=64, train_split=0.8, - randomise_polarisations=True, - osnr=10, + randomise_polarisations=False, + polarisations=True, + osnr=16, #16dB due to amplification with NF 5 ) pytorch_settings = PytorchSettings( - epochs=10000, + epochs=1000, batchsize=2**14, device="cuda", dataloader_workers=24, @@ -64,11 +65,11 @@ model_settings = ModelSettings( # "n_hidden_nodes_3": 4, # "n_hidden_nodes_4": 2, }, - model_activation_func="EOActivation", + model_activation_func="phase_shift", dropout_prob=0, model_layer_function="ONNRect", model_layer_kwargs={"square": True}, - scale=False, + scale=2.0, model_layer_parametrizations=[ { "tensor_name": "weight", @@ -77,13 +78,17 @@ model_settings = ModelSettings( { "tensor_name": "alpha", "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": 1, + }, }, { "tensor_name": "gain", "parametrization": util.complexNN.clamp, "kwargs": { "min": 0, - "max": float("inf"), + "max": None, }, }, { @@ -95,8 +100,12 @@ model_settings = ModelSettings( }, }, { - "tensor_name": "scales", + "tensor_name": "scale", "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": 2, + }, }, { "tensor_name": "angle", @@ -244,9 +253,17 @@ if __name__ == "__main__": pytorch_settings=pytorch_settings, model_settings=model_settings, optimizer_settings=optimizer_settings, - checkpoint_path=".models/best_20241216_221359.tar", + # checkpoint_path=".models/best_20250104_191428.tar", reset_epoch=True, # settings_override={ + # "data_settings": { + # "config_path": "data/20241229-163*-128-16384-100000-*.ini", + # "polarisations": True, + # }, + # "model_settings": { + # "scale": 2.0, + # } + # } # "optimizer_settings": { # "optimizer_kwargs": { # "lr": 0.01, diff --git a/src/single-core-data-gen/add_pypho.py b/src/single-core-regen/signal_gen/add_pypho.py similarity index 100% rename from src/single-core-data-gen/add_pypho.py rename to src/single-core-regen/signal_gen/add_pypho.py diff --git a/src/single-core-data-gen/generate_signal.py b/src/single-core-regen/signal_gen/generate_signal.py similarity index 69% rename from src/single-core-data-gen/generate_signal.py rename to src/single-core-regen/signal_gen/generate_signal.py index 272a06e..d58b28f 100644 --- a/src/single-core-data-gen/generate_signal.py +++ b/src/single-core-regen/signal_gen/generate_signal.py @@ -16,16 +16,17 @@ from datetime import datetime import hashlib from pathlib import Path import time +import h5py from matplotlib import pyplot as plt # noqa: F401 import numpy as np -import add_pypho # noqa: F401 +from . import add_pypho # noqa: F401 import pypho default_config = f""" [glova] -nos = 256 -sps = 256 +sps = 128 +nos = 16384 f0 = 193414489032258.06 symbolrate = 10e9 wisdom_dir = "{str((Path.home() / ".pypho"))}" @@ -37,9 +38,9 @@ length = 10000 gamma = 1.14 alpha = 0.2 D = 17 -S = 0 -birefsteps = 0 -max_delta_beta = 0.4 +S = 0.058 +bireflength = 10 +max_delta_beta = 0.14 ; birefseed = 0xC0FFEE [signal] @@ -47,17 +48,15 @@ max_delta_beta = 0.4 modulation = "pam" mod_order = 4 -mod_depth = 0.8 - +mod_depth = 1 max_jitter = 0.02 ; jitter_seed = 0xC0FFEE - laser_power = 0 -edfa_power = 3 +edfa_power = 0 edfa_nf = 5 - pulse_shape = "gauss" fwhm = 0.33 +osnr = "inf" [data] dir = "data" @@ -71,6 +70,7 @@ def get_config(config_file=None): """ if config_file is None: config_file = Path(__file__).parent / "signal_generation.ini" + config_file = Path(config_file) if not config_file.exists(): with open(config_file, "w") as f: f.write(default_config) @@ -83,7 +83,10 @@ def get_config(config_file=None): conf[section] = {} for key in config[section]: # print(f"{key} = {config[section][key]}") - conf[section][key] = eval(config[section][key]) + try: + conf[section][key] = eval(config[section][key]) + except NameError: + conf[section][key] = float(config[section][key]) # if isinstance(conf[section][key], str): # conf[section][key] = config[section][key].strip('"') return conf @@ -96,7 +99,9 @@ class PDM_IM_IPM: mod_order=8, seed=None, ): - assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1" + assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, ( + "mod_order must be a cube of an integer greater than 1" + ) self.glova = glova self.mod_order = mod_order self.symbols_per_dim = int(np.cbrt(mod_order)) @@ -106,18 +111,11 @@ class PDM_IM_IPM: rs = np.random.RandomState(self.seed) symbols = rs.randint(0, self.mod_order, n) return symbols - + class pam_generator: def __init__( - self, - glova, - mod_order=None, - mod_depth=0.5, - pulse_shape="gauss", - fwhm=0.33, - seed=None, - single_channel=False + self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False ) -> None: self.glova = glova self.pulse_shape = pulse_shape @@ -133,41 +131,36 @@ class pam_generator: wavelet = self.gauss(oversampling=6) else: raise ValueError(f"Unknown pulse shape: {self.pulse_shape}") - + # prepare symbols symbols_x = symbols[0] / (self.mod_order) diffs_x = np.diff(symbols_x, prepend=symbols_x[0]) digital_x = self.generate_digital_signal(diffs_x, max_jitter) - digital_x = np.pad( - digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0) - ) - + digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)) + # create analog signal of diff of symbols E_x = np.convolve(digital_x, wavelet) - + # convert to pam and set modulation depth (scale and move up such that 1 stays at 1) E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth) - + # cut off the wavelet tails E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2] - + # modulate the laser E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x)) - + if not self.single_channel: symbols_y = symbols[1] / (self.mod_order) diffs_y = np.diff(symbols_y, prepend=symbols_y[0]) digital_y = self.generate_digital_signal(diffs_y, max_jitter) - digital_y = np.pad( - digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0) - ) + digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)) E_y = np.convolve(digital_y, wavelet) E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth) E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2] - E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y)) # rotate the signal on the y-polarisation by 90° @@ -175,7 +168,6 @@ class pam_generator: else: E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype) return E - def generate_digital_signal(self, symbols, max_jitter=0): rs = np.random.RandomState(self.seed) @@ -198,15 +190,11 @@ class pam_generator: endpoint=True, ) sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps - pulse = ( - 1 - / (sigma * np.sqrt(2 * np.pi)) - * np.exp(-np.square(sample_points) / (2 * np.square(sigma))) - ) + pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma))) return pulse -def initialize_fiber_and_data(config, input_data_override=None): +def initialize_fiber_and_data(config): py_glova = pypho.setup( nos=config["glova"]["nos"], sps=config["glova"]["sps"], @@ -221,48 +209,54 @@ def initialize_fiber_and_data(config, input_data_override=None): c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos) py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"]) - if input_data_override is not None: - c_data.E_in = input_data_override[0] - noise = input_data_override[1] - else: - config["signal"]["seed"] = config["signal"].get( - "seed", (int(time.time() * 1000)) % 2**32 - ) - config["signal"]["jitter_seed"] = config["signal"].get( - "jitter_seed", (int(time.time() * 1000)) % 2**32 - ) - symbolsrc = pypho.symbols( - py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"] - ) - laser = pypho.lasmod( - py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4 - ) - modulator = pam_generator( - py_glova, - mod_depth=config["signal"]["mod_depth"], - pulse_shape=config["signal"]["pulse_shape"], - fwhm=config["signal"]["fwhm"], - seed=config["signal"]["jitter_seed"], - mod_order=config["signal"]["mod_order"], + osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf")) + + config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32) + config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32) + symbolsrc = pypho.symbols( + py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"] + ) + laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4) + modulator = pam_generator( + py_glova, + mod_depth=config["signal"]["mod_depth"], + pulse_shape=config["signal"]["pulse_shape"], + fwhm=config["signal"]["fwhm"], + seed=config["signal"]["jitter_seed"], + mod_order=config["signal"]["mod_order"], + ) + + symbols_x = symbolsrc(pattern="random") + symbols_y = symbolsrc(pattern="random") + symbols_x[:3] = 0 + symbols_y[:3] = 0 + # symbols_x += 1 + + cw = laser() + + source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y)) + + if osnr != float("inf"): + osnr_lin = 10 ** (osnr / 10) + signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"])) + noise_power = signal_power / osnr_lin + noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal( + 0, 1, source_signal[0]["E"].shape ) + noise_power_is = np.sum(pypho.functions.getpower_W(noise)) + noise = noise * np.sqrt(noise_power / noise_power_is) + noise_power_is = np.sum(pypho.functions.getpower_W(noise)) + source_signal[0]["E"] += noise + source_signal[0]["noise"] = noise_power_is - symbols_x = symbolsrc(pattern="random") - symbols_y = symbolsrc(pattern="random") - symbols_x[:3] = 0 - symbols_y[:3] = 0 - # symbols_x += 1 + # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))] + nf = py_edfa.NF + source_signal = py_edfa(E=source_signal, NF=0) + py_edfa.NF = nf - cw = laser() - - source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y)) - - # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))] - - source_signal = py_edfa(E=source_signal) - - c_data.E_in = source_signal[0]["E"] - noise = source_signal[0]["noise"] + c_data.E_in = source_signal[0]["E"] + noise = source_signal[0]["noise"] py_fiber = pypho.fiber( glova=py_glova, @@ -273,25 +267,21 @@ def initialize_fiber_and_data(config, input_data_override=None): S=config["fiber"]["s"], ) if config["fiber"].get("birefsteps", 0) > 0: - seed = config["fiber"].get( - "birefseed", (int(time.time() * 1000)) % 2**32 - ) + seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32) py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre( py_fiber.l, py_fiber.l / config["fiber"]["birefsteps"], # maxDeltaD=config["fiber"]["d"]/5, - maxDeltaBeta = config["fiber"].get("max_delta_beta", 0), + maxDeltaBeta=config["fiber"].get("max_delta_beta", 0), seed=seed, ) - c_params = pypho.cfiber.ParamsWrapper.from_fiber( - py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200 - ) + c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200) c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova) - return c_fiber, c_data, noise, py_edfa + return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y) -def save_data(data, config): +def save_data(data, config, **metadata): data_dir = Path(config["data"]["dir"]) npy_dir = config["data"].get("npy_dir", "") save_dir = data_dir / npy_dir if len(npy_dir) else data_dir @@ -306,6 +296,7 @@ def save_data(data, config): seed = config["signal"].get("seed", False) jitter_seed = config["signal"].get("jitter_seed", False) birefseed = config["fiber"].get("birefseed", False) + osnr = float(config["signal"].get("osnr", "inf")) config_content = "\n".join(( f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}", @@ -317,14 +308,14 @@ def save_data(data, config): f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"', f'flags = "{config["glova"]["flags"]}"', f"nthreads = {config['glova']['nthreads']}", - " ", + "", "[fiber]", f"length = {config['fiber']['length']}", f"gamma = {config['fiber']['gamma']}", f"alpha = {config['fiber']['alpha']}", f"D = {config['fiber']['d']}", f"S = {config['fiber']['s']}", - f"birefsteps = {config['fiber'].get('birefsteps',0)}", + f"birefsteps = {config['fiber'].get('birefsteps', 0)}", f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}", f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set", "", @@ -334,75 +325,84 @@ def save_data(data, config): f'modulation = "{config["signal"]["modulation"]}"', f"mod_order = {config['signal']['mod_order']}", f"mod_depth = {config['signal']['mod_depth']}", - "" + "", f"max_jitter = {config['signal']['max_jitter']}", f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set", - "" + "", f"laser_power = {config['signal']['laser_power']}", f"edfa_power = {config['signal']['edfa_power']}", f"edfa_nf = {config['signal']['edfa_nf']}", - "" + f"osnr = {osnr}", + "", f'pulse_shape = "{config["signal"]["pulse_shape"]}"', f"fwhm = {config['signal']['fwhm']}", "", "[data]", f'dir = "{str(data_dir)}"', f'npy_dir = "{npy_dir}"', - "file = " + "file = ", )) config_hash = hashlib.md5(config_content.encode()).hexdigest() - save_file = f"{config_hash}.npy" + save_file = f"{config_hash}.h5" config_content += f'"{str(save_file)}"\n' filename_components = ( timestamp.strftime("%Y%m%d-%H%M%S"), config["glova"]["sps"], config["glova"]["nos"], + config["signal"]["osnr"], config["fiber"]["length"], config["fiber"]["gamma"], config["fiber"]["alpha"], config["fiber"]["d"], config["fiber"]["s"], f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}", - config['fiber'].get('birefsteps',0), + config["fiber"].get("birefsteps", 0), config["fiber"].get("max_delta_beta", 0), + int(config["glova"]["symbolrate"] / 1e9), ) lookup_file = "-".join(map(str, filename_components)) + ".ini" - with open(data_dir / lookup_file, "w") as f: + config_filename = data_dir / lookup_file + with open(config_filename, "w") as f: f.write(config_content) - np.save(save_dir / save_file, save_data) + with h5py.File(save_dir / save_file, "w") as outfile: + outfile.create_dataset("data", data=save_data) + outfile.create_dataset("symbols", data=metadata.pop("symbols")) + for key, value in metadata.items(): + # if isinstance(value, dict): + # value = json.dumps(model_runner.convert_arrays(value)) + outfile.attrs[key] = value + # np.save(save_dir / save_file, save_data) - print("Saved config to", data_dir / lookup_file) + print("Saved config to", config_filename) print("Saved data to", save_dir / save_file) + return config_filename + def length_loop(config, lengths, save=True): lengths = sorted(lengths) for length in lengths: print(f"\nGenerating data for fiber length {length}m") - config["fiber"]["length"] = length - - cfiber, cdata, noise, edfa = initialize_fiber_and_data(config) + config["fiber"]["length"] = length + cfiber, cdata, noise, edfa = initialize_fiber_and_data(config) mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) cfiber() mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out)) + print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)") + print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)") - print( - f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)" - ) - print( - f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)" - ) - - - E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}] + E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}] E_tmp = edfa(E=E_tmp) - cdata.E_out = E_tmp[0]['E'] + cdata.E_out = E_tmp[0]["E"] + + mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out)) + print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)") if save: save_data(cdata, config) @@ -411,27 +411,57 @@ def length_loop(config, lengths, save=True): def single_run_with_plot(config, save=True): - cfiber, cdata, noise, edfa = initialize_fiber_and_data(config) + cfiber, cdata, config_filename = single_run(config, save) - mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) - print( - f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)" - ) + in_out_eyes(cfiber, cdata, show_pols=False) + return config_filename + +def single_run(config, save=True): + cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config) + + # mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) + # print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)") + + # estimate osnr + # noise_power = np.mean(noise) + # osnr_lin = mean_power_in / noise_power - 1 + # osnr = 10 * np.log10(osnr_lin) + # print(f"Estimated OSNR: {osnr:.3f} dB") cfiber() - mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out)) - print( - f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)" - ) + # mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out)) + # print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)") - E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}] + # noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha) + + # estimate osnr + # noise_power = np.mean(noise) + # osnr_lin = mean_power_out / noise_power - 1 + # osnr = 10 * np.log10(osnr_lin) + # print(f"Estimated OSNR: {osnr:.3f} dB") + + E_tmp = [{"E": cdata.E_out, "noise": noise}] E_tmp = edfa(E=E_tmp) - cdata.E_out = E_tmp[0]['E'] - if save: - save_data(cdata, config) + cdata.E_out = E_tmp[0]["E"] + # noise = E_tmp[0]["noise"] + + # mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out)) + + # print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)") + + # estimate osnr + # noise_power = np.mean(noise) + # osnr_lin = mean_power_amp / noise_power - 1 + # osnr = 10 * np.log10(osnr_lin) + # print(f"Estimated OSNR: {osnr:.3f} dB") + + config_filename = None + symbols = np.array(symbols) + if save: + config_filename = save_data(cdata, config, **{"symbols": symbols}) + return cfiber,cdata,config_filename - in_out_eyes(cfiber, cdata, show_pols=False) def in_out_eyes(cfiber, cdata, show_pols=False): fig, axs = plt.subplots(2, 2, sharex=True, sharey=True) @@ -595,9 +625,7 @@ def plot_eye_diagram( signal = signal[: head * eye_width] if normalize: signal = signal / np.max(signal) - slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[ - offset % (eye_width + 1) :: eye_width - ] + slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width] plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate for slice in slices: ax.plot(plt_ax, slice, color=color, alpha=0.1) @@ -617,15 +645,27 @@ if __name__ == "__main__": # lengths.append(10*max(ranges)) # lengths = [*lengths, *lengths] lengths = ( - # 8000, 9000, - 10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, - 95000, 100000, 105000, 110000, 115000, 120000 + # 8000, 9000, + 10000, + 20000, + 30000, + 40000, + 50000, + 60000, + 70000, + 80000, + 90000, + 95000, + 100000, + 105000, + 110000, + 115000, + 120000, ) # lengths = (10000,100000) - length_loop(config, lengths, save=True) + # length_loop(config, lengths, save=True) # birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m) - # single_run_with_plot(config, save=False) - + single_run_with_plot(config, save=False) diff --git a/src/single-core-regen/tolerance_testing.py b/src/single-core-regen/tolerance_testing.py new file mode 100644 index 0000000..92b6329 --- /dev/null +++ b/src/single-core-regen/tolerance_testing.py @@ -0,0 +1,723 @@ +""" +tests a given model for tolerance against variations in +- fiber length +- baudrate +- OSNR + +CD, PMD, baudrate need different datasets, osnr is modeled as awgn added to the data before feeding into the model +""" + +from datetime import datetime +from typing import Literal +from matplotlib import pyplot as plt +import numpy as np +from pathlib import Path + +import h5py +import torch +import util +from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings +from hypertraining import models + +from signal_gen.generate_signal import single_run, get_config + +import json + + +class NestedParameterIterator: + def __init__(self, parameters): + """ + parameters: dict with key and value + """ + # self.parameters = parameters + self.names = [] + self.ranges = [] + self.configs = [] + for k, v in parameters.items(): + self.names.append(k) + self.ranges.append(v["range"]) + self.configs.append(v["config"]) + self.n_parameters = len(self.ranges) + + self.idx = 0 + + self.range_idx = [0] * self.n_parameters + self.range_len = [len(r) for r in self.ranges] + self.length = int(np.prod(self.range_len)) + self.out = [] + + for i in range(self.length): + self.out.append([]) + for j in range(self.n_parameters): + element = {self.names[j]: {"value": self.ranges[j][self.range_idx[j]], "config": self.configs[j]}} + self.out[i].append(element) + self.range_idx[-1] += 1 + + # update range_idx back to front + for j in range(self.n_parameters - 1, -1, -1): + if self.range_idx[j] == self.range_len[j]: + self.range_idx[j] = 0 + self.range_idx[j - 1] += 1 + + ... + + def __next__(self): + if self.idx == self.length: + raise StopIteration + self.idx += 1 + return self.out[self.idx - 1] + + def __iter__(self): + return self + + +class model_runner: + def __init__( + self, + # length_range: tuple[int | float] = (50e3, 50e3), + # length_steps: int = 1, + # length_log: bool = False, + # baudrate_range: tuple[int | float] = (10e9, 10e9), + # baudrate_steps: int = 1, + # baudrate_log: bool = False, + # osnr_range: tuple[int | float] = (16, 16), + # osnr_steps: int = 1, + # osnr_log: bool = False, + # dataset_dir: str = "data", + # dataset_datetime_glob: str = "*", + results_dir: str = "tolerance_results/datasets", + # model_dir: str = ".models", + config: str = "signal_generation.ini", + config_dir: str = None, + debug: bool = False, + ): + """ + length_range: lower and upper limit of length, in meters + length_step: step size of length, in meters + baudrate_range: lower and upper limit of baudrate, in Bd + baudrate_step: step size of baudrate, in Bd + osnr_range: lower and upper limit of osnr, in dB + osnr_step: step size of osnr, in dB + dataset_dir: directory containing datasets + dataset_datetime_glob: datetime glob pattern for dataset files + results_dir: directory to save results + model_dir: directory containing models + """ + self.debug = debug + + self.parameters = {} + self.iter = None + + # self.update_length_range(length_range, length_steps, length_log) + # self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log) + # self.update_osnr_range(osnr_range, osnr_steps, osnr_log) + + # self.data_dir = Path(dataset_dir) + # self.data_datetime_glob = dataset_datetime_glob + self.results_dir = Path(results_dir) + # self.model_dir = Path(model_dir) + + config_dir = config_dir or Path(__file__).parent + self.config = config_dir / config + + torch.serialization.add_safe_globals([ + *util.complexNN.__all__, + GlobalSettings, + DataSettings, + ModelSettings, + OptimizerSettings, + PytorchSettings, + models.regenerator, + torch.nn.utils.parametrizations.orthogonal, + ]) + + self.load_model() + + self.datasets = [] + + # def register_parameter(self, name, config): + # self.parameters.append({"name": name, "config": config}) + + def load_results_from_file(self, path): + data, meta = self.load_from_file(path) + self.results = [d.decode() for d in data] + self.parameters = meta["parameters"] + ... + + def load_datasets_from_file(self, path): + data, meta = self.load_from_file(path) + self.datasets = [d.decode() for d in data] + self.parameters = meta["parameters"] + ... + + def update_parameter_range(self, name, config, range, steps, log): + self.parameters[name] = {"config": config, "range": self.update_range(*range, steps, log)} + + def generate_iterations(self): + if len(self.parameters) == 0: + raise ValueError("No parameters registered") + self.iter = NestedParameterIterator(self.parameters) + + def generate_datasets(self): + # get base config + config = get_config(self.config) + + if self.iter is None: + self.generate_iterations() + + for params in self.iter: + current_settings = [] + # params is a list of dictionaries with keys "name", containing a dict with keys "value", "config" + for param in params: + for name, settings in param.items(): + current_settings.append({name: settings["value"]}) + self.nested_set(config, settings["config"], settings["value"]) + settings_strs = [] + for setting in current_settings: + name = list(setting)[0] + settings_strs.append(f"{name}: {float(setting[name]):.2e}") + settings_str = ", ".join(settings_strs) + print(f"Generating dataset for [{settings_str}]") + # TODO look for existing datasets + _, _, path = single_run(config) + self.datasets.append(str(path)) + + datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back") + metadata = {"parameters": self.parameters} + data = np.array(self.datasets, dtype="S") + self.save_to_file(datasets_list_path, data, **metadata) + + @staticmethod + def nested_set(dic, keys, value): + for key in keys[:-1]: + dic = dic.setdefault(key, {}) + dic[keys[-1]] = value + + ## Dataset and model loading + # def find_datasets(self, data_dir=None, data_datetime_glob=None): + # # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini + # data_dir = data_dir or self.data_dir + # data_datetime_glob = data_datetime_glob or self.data_datetime_glob + # self.datasets = {} + # data_dir = Path(data_dir) + # for length in self.lengths: + # for baudrate in self.baudrates: + # # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini" + # dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini" + # datasets = [f for f in data_dir.glob(dataset_glob)] + # if len(datasets) == 0: + # continue + # self.datasets[length] = {} + # if len(datasets) > 1: + # print( + # f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset." + # ) + # # get newest file from creation date + # datasets.sort(key=lambda x: x.stat().st_ctime) + # self.datasets[length][baudrate] = str(datasets[-1]) + + def load_dataset(self, dataset_path): + if self.checkpoint_dict is None: + raise ValueError("Model must be loaded before dataset") + + if self.dataset_path is None: + self.dataset_path = dataset_path + elif self.dataset_path == dataset_path: + return + + symbols = self.checkpoint_dict["settings"]["data_settings"].symbols + data_size = self.checkpoint_dict["settings"]["data_settings"].output_size + dtype = getattr(torch, self.checkpoint_dict["settings"]["data_settings"].dtype) + drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first + randomise_polarisations = self.checkpoint_dict["settings"]["data_settings"].randomise_polarisations + polarisations = self.checkpoint_dict["settings"]["data_settings"].polarisations + num_symbols = None + if self.debug: + num_symbols = 1000 + + config_path = Path(dataset_path) + + dataset = util.datasets.FiberRegenerationDataset( + file_path=config_path, + symbols=symbols, + output_dim=data_size, + drop_first=drop_first, + dtype=dtype, + real=not dtype.is_complex, + randomise_polarisations=randomise_polarisations, + polarisations=polarisations, + num_symbols=num_symbols, + # device="cuda" if torch.cuda.is_available() else "cpu", + ) + + self.dataloader = torch.utils.data.DataLoader( + dataset, batch_size=2**14, pin_memory=True, num_workers=24, prefetch_factor=8, shuffle=False + ) + + return self.dataloader.dataset.orig_symbols + # run model + # return results as array: [fiber_in, fiber_out, fiber_out_noisy, regen_out] + + def load_model(self, model_path: str | None = None): + if model_path is None: + self.model = None + self.model_path = None + self.checkpoint_dict = None + return + + path = Path(model_path) + + if self.model_path is None: + self.model_path = path + elif path == self.model_path: + return + + self.dataset_path = None # reset dataset path, as the shape depends on the model + + self.checkpoint_dict = torch.load(path, weights_only=True) + dims = self.checkpoint_dict["model_kwargs"].pop("dims") + self.model = models.regenerator(*dims, **self.checkpoint_dict["model_kwargs"]) + self.model.load_state_dict(self.checkpoint_dict["model_state_dict"]) + + ## Model evaluation + def run_model_evaluation(self, model_path: str, datasets: str | None = None): + self.load_model(model_path) + # iterate over datasets and osnr values: + # load dataset, add noise, run model, return results + # save results to file + self.results = [] + + if datasets is not None: + self.load_datasets_from_file(datasets) + + n_datasets = len(self.datasets) + for i, dataset_path in enumerate(self.datasets): + conf = get_config(dataset_path) + mpath = Path(model_path) + model_base = mpath.stem + print(f"({1+i}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}") + + results_path = self.build_path( + dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base + ) + + orig_symbols = self.load_dataset(dataset_path) + + data, loss = self.run_model() + + metadata = { + "model_path": model_path, + "dataset_path": dataset_path, + "loss": loss, + "sps": conf["glova"]["sps"], + "orig_symbols": orig_symbols + # "config": conf, + # "checkpoint_dict": self.checkpoint_dict, + # "nos": self.dataloader.dataset.num_symbols, + } + + self.save_to_file(results_path, data, **metadata) + self.results.append(str(results_path)) + + results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back") + metadata = {"parameters": self.parameters} + data = np.array(self.results, dtype="S") + self.save_to_file(results_list_path, data, **metadata) + + def run_model(self): + loss = 0 + datas = [] + + self.model.eval() + model = self.model.to("cuda" if torch.cuda.is_available() else "cpu") + with torch.no_grad(): + for batch in self.dataloader: + x = batch["x_stacked"] + y = batch["y_stacked"] + fiber_in = batch["plot_target"] + # fiber_out = batch["plot_clean"] + fiber_out = batch["plot_data"] + timestamp = batch["timestamp"] + angle = batch["mean_angle"] + x = x.to("cuda" if torch.cuda.is_available() else "cpu") + angle = angle.to("cuda" if torch.cuda.is_available() else "cpu") + regen = model(x, -angle) + regen = regen.to("cpu") + loss += util.complexNN.complex_mse_loss(regen, y, power=True).item() + # shape: [batch_size, 4] + plot_regen = regen[:, :2] + plot_regen = plot_regen.view(plot_regen.shape[0], -1, 2) + plot_regen = plot_regen[:, plot_regen.shape[1] // 2, :] + + data_out = torch.cat( + ( + fiber_in, + fiber_out, + # fiber_out_noisy, + plot_regen, + timestamp.view(-1, 1), + ), + dim=1, + ) + datas.append(data_out) + + data_out = torch.cat(datas, dim=0).numpy() + + return data_out, loss + + ## File I/O + @staticmethod + def save_to_file(path: str, data: np.ndarray, **metadata: dict): + # create directory if it doesn't exist + path.parent.mkdir(parents=True, exist_ok=True) + + with h5py.File(path, "w") as outfile: + outfile.create_dataset("data", data=data) + for key, value in metadata.items(): + if isinstance(value, dict): + value = json.dumps(model_runner.convert_arrays(value)) + outfile.attrs[key] = value + + @staticmethod + def convert_arrays(dict_in): + """ + convert ndarrays in (nested) dict to lists + """ + dict_out = {} + for key, value in dict_in.items(): + if isinstance(value, dict): + dict_out[key] = model_runner.convert_arrays(value) + elif isinstance(value, np.ndarray): + dict_out[key] = value.tolist() + else: + dict_out[key] = value + return dict_out + + @staticmethod + def load_from_file(path: str): + with h5py.File(path, "r") as infile: + data = infile["data"][:] + metadata = {} + for key in infile.attrs.keys(): + if isinstance(infile.attrs[key], (str, bytes, bytearray)): + try: + metadata[key] = json.loads(infile.attrs[key]) + except json.JSONDecodeError: + metadata[key] = infile.attrs[key] + else: + metadata[key] = infile.attrs[key] + return data, metadata + + ## Utility functions + @staticmethod + def logrange(start, stop, num, endpoint=False): + lower, upper = np.log10((start, stop)) + return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10) + + @staticmethod + def build_path( + *elements, parent_dir: str | Path | None = None, filetype="h5", timestamp: Literal["no", "front", "back"] = "no" + ): + suffix = f".{filetype}" if not filetype.startswith(".") else filetype + if timestamp != "no": + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + if timestamp == "front": + elements = (ts, *elements) + elif timestamp == "back": + elements = (*elements, ts) + path = "_".join(elements) + path += suffix + + if parent_dir is not None: + path = Path(parent_dir) / path + return path + + @staticmethod + def update_range(min, max, n_steps, log): + if log: + range = model_runner.logrange(min, max, n_steps, endpoint=True) + else: + range = np.linspace(min, max, n_steps, endpoint=True) + return range + + +class model_evaluation_result: + def __init__( + self, + *, + length=None, + baudrate=None, + osnr=None, + model_path=None, + dataset_path=None, + loss=None, + sps=None, + **kwargs, + ): + self.length = length + self.baudrate = baudrate + self.osnr = osnr + self.model_path = model_path + self.dataset_path = dataset_path + self.loss = loss + self.sps = sps + + self.sers = None + self.bers = None + self.eye_stats = None + + +class evaluator: + def __init__(self, datasets: list[str]): + """ + datasets: iterable of dataset paths + data_dir: directory containing datasets + """ + self.datasets = datasets + self.results = [] + + def evaluate_datasets(self, plot=False): + ## iterate over datasets + # load dataset + for dataset in self.datasets: + model, dataset_name = dataset.split("/")[-2:] + print(f"\nEvaluating model {model} with dataset {dataset_name}") + data, metadata = model_runner.load_from_file(dataset) + result = model_evaluation_result(**metadata) + + data = self.prepare_data(data, sps=metadata["sps"]) + + try: + sym_x, sym_y = metadata["orig_symbols"] + except (TypeError, KeyError, ValueError): + sym_x, sym_y = None, None + + self.evaluate_eye(data, result, title=dataset.split("/")[-1], plot=False) + self.evaluate_ser_ber(data, result, sym_x, sym_y) + print("BER:") + self.print_dict(result.bers["regen"]) + print() + print("SER:") + self.print_dict(result.sers["regen"]) + print() + + self.results.append(result) + if plot: + plt.show() + + + def evaluate_eye(self, data, result, title=None, plot=False): + eye = util.eye_diagram.eye_diagram( + data, + channel_names=[ + "fiber_in_x", + "fiber_in_y", + # "fiber_out_x", + # "fiber_out_y", + "fiber_out_x", + "fiber_out_y", + "regen_x", + "regen_y", + ], + ) + + eye.analyse() + eye.plot(title=title or "Eye diagram", show=plot) + + result.eye_stats = eye.eye_stats + return eye.eye_stats + ... + + def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None): + if result.eye_stats is None: + self.evaluate_eye(data, result) + + symbols = [] + sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}} + bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}} + + for channel_data, stats in zip(data, result.eye_stats): + timestamps = channel_data[0] + dat = channel_data[1] + + channel_name = stats["channel_name"] + if stats["success"]: + thresholds = stats["thresholds"] + time_midpoint = stats["time_midpoint"] + else: + if channel_name.endswith("x"): + thresholds = result.eye_stats[0]["thresholds"] + time_midpoint = result.eye_stats[0]["time_midpoint"] + elif channel_name.endswith("y"): + thresholds = result.eye_stats[1]["thresholds"] + time_midpoint = result.eye_stats[1]["time_midpoint"] + else: + levels = np.linspace(np.min(dat), np.max(dat), 4) + thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels) + time_midpoint = 1.0 + + # time_offset = time_midpoint - 0.5 + # # time_offset = 0 + + # index_offset = np.argmin(np.abs((timestamps - time_offset) % 1.0)) + + nos = len(timestamps) // result.sps + + # idx = np.arange(index_offset, len(timestamps), result.sps).astype(int) + + # if time_offset < 0: + # idx = np.insert(idx, 0, 0) + + idx = list(range(0,len(timestamps),result.sps)) + + idx = idx[:nos] + + data_sampled = dat[idx] + detected_symbols = self.detect_symbols(data_sampled, thresholds) + + symbols.append({"channel_name": channel_name, "symbols": detected_symbols}) + + symbols_x_gt = sym_x or symbols[0]["symbols"] + symbols_y_gt = sym_y or symbols[1]["symbols"] + + symbols_x_fiber_out = symbols[2]["symbols"] + symbols_y_fiber_out = symbols[3]["symbols"] + + symbols_x_regen = symbols[4]["symbols"] + symbols_y_regen = symbols[5]["symbols"] + + sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out) + sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out) + sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen) + sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen) + + result.sers = sers + result.bers = bers + + @staticmethod + def calculate_ser_ber(symbols_gt, symbols): + # levels = 4 + # symbol difference -> bit error count + # |rx - tx| = 0 -> 0 + # |rx - tx| = 1 -> 1 + # |rx - tx| = 2 -> 2 + # |rx - tx| = 3 -> 1 + # assuming gray coding -> 0: 00, 1: 01, 2: 11, 3: 10 + bec_map = {0: 0, 1: 1, 2: 2, 3: 1, np.nan: 2} + + ser = {} + ber = {} + ser["n_symbols"] = len(symbols_gt) + ser["n_errors"] = np.sum(symbols != symbols_gt) + ser["total"] = float(ser["n_errors"] / ser["n_symbols"]) + + bec = np.vectorize(bec_map.get)(np.abs(symbols - symbols_gt)) + bit_errors = np.sum(bec) + + ber["n_bits"] = len(symbols_gt) * 2 + ber["n_errors"] = bit_errors + ber["total"] = float(ber["n_errors"] / ber["n_bits"]) + + return ser, ber + + @staticmethod + def print_dict(d: dict, indent=2, logarithmic=False, level=0): + for key, value in d.items(): + if isinstance(value, dict): + print(f"{' ' * indent * level}{key}:") + evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1) + else: + if isinstance(value, float): + if logarithmic: + if value == 0: + value = -np.inf + else: + value = np.log10(value) + print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="") + else: + print(f"{' ' * indent * level}{key}: {value}\t", end="") + print() + + @staticmethod + def detect_symbols(samples, thresholds=None): + thresholds = (1 / 6, 3 / 6, 5 / 6) if thresholds is None else thresholds + thresholds = (-np.inf, *thresholds, np.inf) + symbols = np.digitize(samples, thresholds) - 1 + return symbols + + @staticmethod + def prepare_data(data, sps=None): + data = data.transpose(1, 0) + timestamps = data[-1].real + data = data[:-1] + if sps is not None: + timestamps /= sps + + # data = np.stack( + # ( + # *data[0:2], # fiber_in_x, fiber_in_y + # # *data_[2:4], # fiber_out_x, fiber_out_y + # *data[4:6], # fiber_out_noisy_x, fiber_out_noisy_y + # *data[6:8], # regen_out_x, regen_out_y + # ), + # axis=0, + # ) + + data_eye = [] + for channel_values in data: + channel_values = np.square(np.abs(channel_values)) + data_eye.append(np.stack((timestamps, channel_values), axis=0)) + + data_eye = np.stack(data_eye, axis=0) + + return data_eye + + +def generate_data(parameters, runner=None): + runner = runner or model_runner() + for param in parameters: + runner.update_parameter_range(*param) + runner.generate_iterations() + print(f"{runner.iter.length} parameter combinations") + runner.generate_datasets() + + return runner + +if __name__ == "__main__": + model_path = ".models/best_20250110_191149.tar" # D 17, OSNR 100, delta_beta 0.14, baud 10e9 + + parameters = ( + # name, config keys, (min, max), n_steps, log + # ("D", ("fiber", "d"), (28,30), 3, False), + # ("S", ("fiber", "s"), (0, 0.058), 2, False), + ("OSNR", ("signal", "osnr"), (20, 40), 5, False), + # ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False), + # ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True), + ) + + datasets = None + results = None + + # datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5" + results = "tolerance_results/datasets/results_list_20250110_232639.h5" + + runner = model_runner() + # generate_data(parameters, runner) + + + if results is None: + if datasets is None: + generate_data(parameters, runner) + else: + runner.load_datasets_from_file(datasets) + print(f"{len(runner.datasets)} loaded") + + runner.run_model_evaluation(model_path) + else: + runner.load_results_from_file(results) + + # print(runner.parameters) + # print(runner.results) + + eval = evaluator(runner.results) + eval.evaluate_datasets(plot=True) diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py index 9e2cfe0..f488f5b 100644 --- a/src/single-core-regen/util/complexNN.py +++ b/src/single-core-regen/util/complexNN.py @@ -481,6 +481,15 @@ class Identity(nn.Module): def forward(self, x): return x + +class phase_shift(nn.Module): + def __init__(self, size): + super(phase_shift, self).__init__() + self.size = size + self.phase = nn.Parameter(torch.rand(size)) + + def forward(self, x): + return x * torch.exp(1j*self.phase) class PowRot(nn.Module): @@ -531,19 +540,19 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor): class EOActivation(nn.Module): def __init__(self, size=None): - # 10.1109/SiPhotonics60897.2024.10543376 + # 10.1109/JSTQE.2019.2930455 super(EOActivation, self).__init__() if size is None: raise ValueError("Size must be specified") self.size = size - self.alpha = nn.Parameter(torch.ones(size)) - self.V_bias = nn.Parameter(torch.ones(size)) - self.gain = nn.Parameter(torch.ones(size)) + self.alpha = nn.Parameter(torch.rand(size)) + self.V_bias = nn.Parameter(torch.rand(size)) + self.gain = nn.Parameter(torch.rand(size)) # if bias: # self.phase_bias = nn.Parameter(torch.zeros(size)) # else: # self.register_buffer("phase_bias", torch.zeros(size)) - self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi) + # self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi) self.register_buffer("responsivity", torch.ones(size)*0.9) self.register_buffer("V_pi", torch.ones(size)*3) @@ -551,17 +560,17 @@ class EOActivation(nn.Module): def reset_weights(self): if "alpha" in self._parameters: - self.alpha.data = torch.ones(self.size)*0.5 + self.alpha.data = torch.rand(self.size) if "V_pi" in self._parameters: - self.V_pi.data = torch.ones(self.size)*3 + self.V_pi.data = torch.rand(self.size)*3 if "V_bias" in self._parameters: - self.V_bias.data = torch.zeros(self.size) + self.V_bias.data = torch.randn(self.size) if "gain" in self._parameters: - self.gain.data = torch.ones(self.size) + self.gain.data = torch.rand(self.size) if "responsivity" in self._parameters: self.responsivity.data = torch.ones(self.size)*0.9 - if "bias" in self._parameters: - self.phase_bias.data = torch.zeros(self.size) + # if "bias" in self._parameters: + # self.phase_bias.data = torch.zeros(self.size) def forward(self, x: torch.Tensor): phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8) @@ -570,12 +579,11 @@ class EOActivation(nn.Module): return ( 1j * torch.sqrt(1 - self.alpha) - * torch.exp(-0.5j * (intermediate + self.phase_bias)) + * torch.exp(-0.5j * intermediate) * torch.cos(0.5 * intermediate) * x ) - class Pow(nn.Module): """ implements the activation function @@ -716,6 +724,7 @@ __all__ = [ MZISingle, EOActivation, photodiode, + phase_shift, # SaturableAbsorberLambertW, # SaturableAbsorber, # SpreadLayer, diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py index 4967639..fe6d01a 100644 --- a/src/single-core-regen/util/datasets.py +++ b/src/single-core-regen/util/datasets.py @@ -1,4 +1,5 @@ from pathlib import Path +import h5py import torch from torch.utils.data import Dataset @@ -24,8 +25,22 @@ import multiprocessing as mp # def __len__(self): # return len(self.indices) +def load_from_file(datapath): + if str(datapath).endswith('.h5'): + symbols = None + with h5py.File(datapath, "r") as infile: + data = infile["data"][:] + try: + symbols = infile["symbols"][:] + except KeyError: + pass + else: + symbols = None + data = np.load(datapath) + return data, symbols -def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None): + +def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None): filepath = Path(config_path) filepath = filepath.parent.glob(filepath.name) config = configparser.ConfigParser() @@ -40,15 +55,21 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals if symbols is None: symbols = int(config["glova"]["nos"]) - skipfirst + + data, orig_symbols = load_from_file(datapath) - data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)] + data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)] + orig_symbols = orig_symbols[skipfirst:symbols+skipfirst] timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps)) - if normalize: - # square gets normalized to 1, as the power is (proportional to) the square of the amplitude - a, b, c, d = np.square(data.T) - a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d)) - data = np.sqrt(np.array([a, b, c, d]).T) + data *= np.sqrt(normalize) + + # if normalize: + # # square gets normalized to 1, as the power is (proportional to) the square of the amplitude + # a, b, c, d = data.T + # a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d)) + # a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d)) + # data = np.array([a, b, c, d]).T if real: data = np.abs(data) @@ -59,7 +80,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals data = torch.tensor(data, device=device, dtype=dtype) - return data, config + return data, config, orig_symbols def roll_along(arr, shifts, dim): @@ -114,7 +135,8 @@ class FiberRegenerationDataset(Dataset): dtype: torch.dtype = None, real: bool = False, device=None, - osnr: float = None, + # osnr: float|None = None, + polarisations = None, randomise_polarisations: bool = False, repeat_randoms: int = 1, **kwargs, @@ -151,65 +173,50 @@ class FiberRegenerationDataset(Dataset): self.randomise_polarisations = randomise_polarisations - faux = kwargs.pop("faux", False) - - if faux: - data_raw = np.array( - [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)], - dtype=np.complex128, + data_raw = None + self.config = None + files = [] + self.orig_symbols = None + for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]: + data, config, orig_syms = load_data( + file_path, + skipfirst=drop_first, + symbols=kwargs.get("num_symbols", None), + real=real, + normalize=1000, + device=device, + dtype=dtype, ) - data_raw = torch.tensor(data_raw, device=device, dtype=dtype) - timestamps = torch.arange(12800) - - data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1) - - self.config = { - "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"}, - "glova": {"sps": 128}, - } - else: - data_raw = None - self.config = None - files = [] - for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]: - data, config = load_data( - file_path, - skipfirst=drop_first, - symbols=kwargs.get("num_symbols", None), - real=real, - normalize=True, - device=device, - dtype=dtype, - ) - if data_raw is None: - data_raw = data + if orig_syms is not None: + if self.orig_symbols is None: + self.orig_symbols = orig_syms else: - data_raw = torch.cat([data_raw, data], dim=0) - if self.config is None: - self.config = config - else: - assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same" - files.append(config["data"]["file"].strip('"')) - self.config["data"]["file"] = str(files) + self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1) + + if data_raw is None: + data_raw = data + else: + data_raw = torch.cat([data_raw, data], dim=0) + if self.config is None: + self.config = config + else: + assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same" + files.append(config["data"]["file"].strip('"')) + self.config["data"]["file"] = str(files) - # if polarisations is not None: - # self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1) - # for i, angle in enumerate(torch.tensor(np.array(polarisations))): - # data_raw_copy = data_raw.clone() - # if angle == 0: - # continue - # sine = torch.sin(angle) - # cosine = torch.cos(angle) - # data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine - # data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine - # if i == 0: - # data_raw = data_raw_copy - # else: - # data_raw = torch.cat([data_raw, data_raw_copy], dim=0) + # if polarisations is not None: + # data_raw_clone = data_raw.clone() + # # rotate the polarisation by 180 degrees + # data_raw_clone[2, :] *= -1 + # data_raw_clone[3, :] *= -1 + # data_raw = torch.cat([data_raw, data_raw_clone], dim=0) + + self.polarisations = bool(polarisations) self.device = data_raw.device self.samples_per_symbol = int(self.config["glova"]["sps"]) + # self.num_symbols = int(self.config["glova"]["nos"]) self.samples_per_slice = int(symbols * self.samples_per_symbol) self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol @@ -290,6 +297,34 @@ class FiberRegenerationDataset(Dataset): fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0) fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0) + # fiber_out: [E_out_x, E_out_y, timestamps] + + # add noise related to amplification necessary due to splitting of the signal + gain_lin = output_dim*2 + edfa_nf = float(self.config["signal"]["edfa_nf"]) + nf_lin = 10**(edfa_nf/10) + f0 = float(self.config["glova"]["f0"]) + + noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9 + + noise = torch.randn_like(fiber_out[:2, :]) + noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1)) + noise = noise * torch.sqrt(noise_add / noise_power) + fiber_out[:2, :] += noise + + + + + # if osnr is None: + # noisy = fiber_out[:2, :] + # else: + # noisy = self.add_noise(fiber_out[:2, :], osnr) + + # fiber_out = torch.cat([fiber_out, noisy], dim=0) + + # fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy] + + if repeat_randoms > 1: fiber_in = fiber_in.repeat(1, 1, repeat_randoms) fiber_out = fiber_out.repeat(1, 1, repeat_randoms) @@ -298,28 +333,34 @@ class FiberRegenerationDataset(Dataset): repeat_randoms = 1 if self.randomise_polarisations: - angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi + angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi # start_angle = torch.rand(1) * 2 * torch.pi # angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi else: - angles = torch.zeros(data_raw.shape[-1]) + angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device) sin = torch.sin(angles) cos = torch.cos(angles) rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2) data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T + # data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T fiber_out = torch.cat((fiber_out, data_rot), dim=0) fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0) - if osnr is not None: - popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1) - noise = torch.randn_like(fiber_out[:2, :, :]) - pn = torch.mean(noise.abs().flatten(), dim=-1) - noise = noise * (popt / pn) * 10 ** (-osnr / 20) - fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise) + # fiber_in: + # 0 E_in_x, + # 1 E_in_y, + # 2 timestamps + + # fiber_out: + # 0 E_out_x, + # 1 E_out_y, + # 2 timestamps, + # 3 E_out_x_rot, + # 4 E_out_y_rot, + # 5 angle - # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) @@ -349,6 +390,22 @@ class FiberRegenerationDataset(Dataset): def __len__(self): return self.fiber_in.shape[0] + + def add_noise(self, data, osnr): + osnr_lin = 10**(osnr/10) + popt = torch.mean(data.abs().square().squeeze(), dim=-1) + noise = torch.randn_like(data) + pn = torch.mean(noise.abs().square().squeeze(), dim=-1) + + mult = torch.sqrt(popt/(pn*osnr_lin)) + mult = mult * torch.eye(popt.shape[0], device=mult.device) + mult = mult.to(dtype=noise.dtype) + + noise = mult @ noise + pn = torch.mean(noise.abs().square().squeeze(), dim=-1) + noisy = data + noise + return noisy + def __getitem__(self, idx): if isinstance(idx, slice): @@ -357,14 +414,19 @@ class FiberRegenerationDataset(Dataset): # fiber in: [E_in_x, E_in_y, timestamps] # fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle] + # if self.polarisations: + output_dim = self.output_dim // 2 + self.output_dim = output_dim * 2 + fiber_in = self.fiber_in[idx].squeeze() fiber_out = self.fiber_out[idx].squeeze() - fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim] - fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim] + fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim] + fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim] + + fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1) + fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1) - fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1) - fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1) # data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim] @@ -372,11 +434,36 @@ class FiberRegenerationDataset(Dataset): # angle = self.angles[idx] - center_angle = fiber_out[5, self.output_dim // 2, 0] + # fiber_in: + # 0 E_in_x, + # 1 E_in_y, + # 2 timestamps + + # fiber_out: + # 0 E_out_x, + # 1 E_out_y, + # 2 timestamps, + # 3 E_out_x_rot, + # 4 E_out_y_rot, + # 5 angle + + center_angle = fiber_out[0, output_dim // 2, 0] angles = fiber_out[5, :, 0] - plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone() - plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone() - data = fiber_out[3:5, :, 0] + plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone() + plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone() + data = fiber_out[0:2, :, 0] + # fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone() + + + # if self.polarisations: + # rot = int(np.random.randint(2)*2-1) + # pol_flipped_data[0:1, :] = rot*data[0, :] + # pol_flipped_data[1, :] = rot*data[1, :] + # plot_data_rot[0] = rot*plot_data_rot[0] + # plot_data_rot[1] = rot*plot_data_rot[1] + # center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0 + # angles = angles + (rot - 1) * torch.pi/2 + # if self.randomise_polarisations: # data = data.mT @@ -389,16 +476,27 @@ class FiberRegenerationDataset(Dataset): # angle = torch.zeros_like(angle) # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter) - angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim) - angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim) + # angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim) + # angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim) # sop = self.polarimeter(plot_data) # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1) # angle = data_slice[1, 3, self.output_dim // 2, 0].real - target = fiber_in[:2, self.output_dim // 2, 0] - plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone() - target_timestamp = fiber_in[2, self.output_dim // 2, 0].real + target = fiber_in[:2, output_dim // 2, 0] + plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone() + target_timestamp = fiber_in[2, output_dim // 2, 0].real ... + if self.polarisations: + rot = int(np.random.randint(2)*2-1) + data = rot*data + target = rot*target + plot_data_rot = rot*plot_data_rot + center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0 + angles = angles + (rot - 1) * torch.pi/2 + + pol_flipped_data = -data + pol_flipped_target = -target + # data_timestamps = data[-1,:].real # data = data[:-1, :] # target_timestamp = target[-1].real @@ -407,13 +505,15 @@ class FiberRegenerationDataset(Dataset): # transpose to interleave the x and y data in the output tensor data = data.transpose(0, 1).flatten().squeeze() - angle_data = angle_data.transpose(0, 1).flatten().squeeze() - angle_data2 = angle_data2.transpose(0,1).flatten().squeeze() + pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze() + # angle_data = angle_data.transpose(0, 1).flatten().squeeze() + # angle_data2 = angle_data2.transpose(0,1).flatten().squeeze() center_angle = center_angle.flatten().squeeze() angles = angles.flatten().squeeze() # data_timestamps = data_timestamps.flatten().squeeze() # target = target.transpose(0,1).flatten().squeeze() target = target.flatten().squeeze() + pol_flipped_target = pol_flipped_target.flatten().squeeze() target_timestamp = target_timestamp.flatten().squeeze() plot_target = plot_target.flatten().squeeze() plot_data = plot_data.flatten().squeeze() @@ -421,17 +521,22 @@ class FiberRegenerationDataset(Dataset): return { "x": data, + "x_flipped": pol_flipped_data, + "x_stacked": torch.cat([data, pol_flipped_data], dim=-1), "y": target, - "center_angle": center_angle, - "angles": angles, + "y_flipped": pol_flipped_target, + "y_stacked": torch.cat([target, pol_flipped_target], dim=-1), + # "center_angle": center_angle, + # "angles": angles, "mean_angle": angles.mean(), # "sop": sop, - "angle_data": angle_data, - "angle_data2": angle_data2, + # "angle_data": angle_data, + # "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_target": plot_target, "plot_data": plot_data, "plot_data_rot": plot_data_rot, + # "plot_clean": fiber_out_plot_clean, } def complex_max(self, data, dim=-1): diff --git a/src/single-core-regen/util/eye_diagram.py b/src/single-core-regen/util/eye_diagram.py index 61115e1..74ea0fe 100644 --- a/src/single-core-regen/util/eye_diagram.py +++ b/src/single-core-regen/util/eye_diagram.py @@ -82,7 +82,6 @@ class eye_diagram: self.vertical_bins = vertical_bins self.multi_threaded = multithreaded self.eye_built = False - self.analyse() def generate_eye_data(self): self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False) @@ -126,6 +125,7 @@ class eye_diagram: rows = int(np.ceil(self.channels / cols)) fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False) fig.suptitle(title) + fig.tight_layout() ax = np.atleast_1d(ax).transpose().flatten() for i in range(self.channels): ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}") @@ -147,19 +147,21 @@ class eye_diagram: yspan = ymax - ymin ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan)) if stats and self.eye_stats[i]["success"]: - # add min_area above the plot - ax[i].annotate( - f"Min Area: {self.eye_stats[i]['min_area']:.2e}", - xy=(0.05, ymax + 0.05 * yspan), - # xycoords="axes fraction", - ha="left", - va="center", - bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), - ) + # # add min_area above the plot + # ax[i].annotate( + # f"Min Area: {self.eye_stats[i]['min_area']:.2e}", + # xy=(0.05, ymax + 0.05 * yspan), + # # xycoords="axes fraction", + # ha="left", + # va="center", + # bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), + # ) if all_stats: ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--") - ax[i].set_yticks(self.eye_stats[i]["levels"]) + y_ticks = (*self.eye_stats[i]["levels"],*self.eye_stats[i]["thresholds"]) + # y_ticks = np.sort(y_ticks) + ax[i].set_yticks(y_ticks) # add arrows for amplitudes for j in range(len(self.eye_stats[i]["amplitudes"])): ax[i].annotate( @@ -193,35 +195,35 @@ class eye_diagram: except (ValueError, IndexError): pass # add arrows for eye widths - for j in range(len(self.eye_stats[i]["widths"])): - try: - left = np.max(self.eye_stats[i]["time_clusters"][j][0]) - right = np.min(self.eye_stats[i]["time_clusters"][j][1]) - vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2 + # for j in range(len(self.eye_stats[i]["widths"])): + # try: + # left = np.max(self.eye_stats[i]["time_clusters"][j][0]) + # right = np.min(self.eye_stats[i]["time_clusters"][j][1]) + # vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2 - ax[i].annotate( - "", - xy=(left, vertical), - xytext=(right, vertical), - arrowprops=dict(arrowstyle="<->", facecolor="black"), - ) - ax[i].annotate( - f"{self.eye_stats[i]['widths'][j]:.2e}", - xy=((left + right) / 2 - 0.15, vertical + 0.01), - bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), - ) - except (ValueError, IndexError): - pass + # ax[i].annotate( + # "", + # xy=(left, vertical), + # xytext=(right, vertical), + # arrowprops=dict(arrowstyle="<->", facecolor="black"), + # ) + # ax[i].annotate( + # f"{self.eye_stats[i]['widths'][j]:.2e}", + # xy=((left + right) / 2 - 0.15, vertical + 0.01), + # bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), + # ) + # except (ValueError, IndexError): + # pass - # add area - for j in range(len(self.eye_stats[i]["areas"])): - horizontal = self.eye_stats[i]["time_midpoint"] - vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2 - ax[i].annotate( - f"{self.eye_stats[i]['areas'][j]:.2e}", - xy=(horizontal + 0.035, vertical - 0.07), - bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), - ) + # # add area + # for j in range(len(self.eye_stats[i]["areas"])): + # horizontal = self.eye_stats[i]["time_midpoint"] + # vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2 + # ax[i].annotate( + # f"{self.eye_stats[i]['areas'][j]:.2e}", + # xy=(horizontal + 0.035, vertical - 0.07), + # bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), + # ) fig.tight_layout() @@ -229,6 +231,12 @@ class eye_diagram: plt.show() return fig + @staticmethod + def calculate_thresholds(levels): + ret = np.cumsum(levels, dtype=float) + ret[2:] = ret[2:] - ret[:-2] + return ret[1:]/2 + def analyse_single(self, data, index): warnings.filterwarnings("error") eye_stats = {} @@ -238,12 +246,15 @@ class eye_diagram: time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels) - eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2 + eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2 + eye_stats["time_midpoint"] = 1.0 eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels( data, approx_levels, time_bounds ) + eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"]) + eye_stats["amplitudes"] = np.diff(eye_stats["levels"]) eye_stats["heights"] = eye_diagram.calculate_eye_heights( @@ -260,22 +271,23 @@ class eye_diagram: # if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])): # raise ValueError - eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"] - eye_stats["mean_area"] = np.mean(eye_stats["areas"]) - eye_stats["min_area"] = np.min(eye_stats["areas"]) + # eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"] + # eye_stats["mean_area"] = np.mean(eye_stats["areas"]) + # eye_stats["min_area"] = np.min(eye_stats["areas"]) eye_stats["success"] = True except (RuntimeWarning, UserWarning, ValueError): eye_stats["success"] = False - eye_stats["time_midpoint"] = 0 - eye_stats["levels"] = np.zeros(self.n_levels) - eye_stats["amplitude_clusters"] = [] - eye_stats["amplitudes"] = np.zeros(self.n_levels - 1) - eye_stats["heights"] = np.zeros(self.n_levels - 1) - eye_stats["widths"] = np.zeros(self.n_levels - 1) - eye_stats["areas"] = np.zeros(self.n_levels - 1) - eye_stats["mean_area"] = 0 - eye_stats["min_area"] = 0 + eye_stats["time_midpoint"] = None + eye_stats["levels"] = None + eye_stats["thresholds"] = None + eye_stats["amplitude_clusters"] = None + eye_stats["amplitudes"] = None + eye_stats["heights"] = None + eye_stats["widths"] = None + # eye_stats["areas"] = np.zeros(self.n_levels - 1) + # eye_stats["mean_area"] = 0 + # eye_stats["min_area"] = 0 warnings.resetwarnings() return eye_stats @@ -441,7 +453,8 @@ if __name__ == "__main__": data = generate_sample_data(length, noise=0.005) eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256) - attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area") + eye.analyse() + attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area") for i, channel in enumerate(eye.eye_stats): print(f"Channel {i}") print_data = {attr: channel[attr] for attr in attrs} diff --git a/src/single-core-regen/util/plot.py b/src/single-core-regen/util/plot.py index fe1f407..aebea86 100644 --- a/src/single-core-regen/util/plot.py +++ b/src/single-core-regen/util/plot.py @@ -1,6 +1,9 @@ import matplotlib.pyplot as plt import numpy as np -from .datasets import load_data +if __name__ == "__main__": + from datasets import load_data +else: + from .datasets import load_data def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True): """Plot an eye diagram for the data given by filepath. @@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0 raise ValueError("Either path or data and sps must be given.") if path is not None: data, config = load_data(path, skipfirst, symbols) + data = data.detach().cpu().numpy()[:, :4] sps = int(config["glova"]["sps"]) if sps is None: raise ValueError("sps not set.") @@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0 plt.show() return fig + +if __name__ == "__main__": + eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False) \ No newline at end of file