diff --git a/.gitignore b/.gitignore index 1023046..1374361 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -tolerance_results/datasets/* +tolerance_results/* +data/* diff --git a/data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini b/data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini new file mode 100644 index 0000000..62ea501 --- /dev/null +++ b/data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1 +size 649 diff --git a/data/npys/6789fdea2609799ef2e975907625b79a.h5 b/data/npys/6789fdea2609799ef2e975907625b79a.h5 new file mode 100644 index 0000000..d7f272d --- /dev/null +++ b/data/npys/6789fdea2609799ef2e975907625b79a.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269 +size 134481920 diff --git a/notes/tolerance_testing.md b/notes/tolerance_testing.md new file mode 100644 index 0000000..923fb87 --- /dev/null +++ b/notes/tolerance_testing.md @@ -0,0 +1,59 @@ +# Baseline Models + +## a) D+S, pol_error 0, ortho_error 0, DGD 0 + +dataset + +```raw + data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini +``` + +model + +```raw + .models/best_20250118_225918.tar +``` + +## b) D+S, pol_error 0.4, ortho_error 0, DGD 0 + +dataset + +```raw + data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini +``` + +model + +```raw + .models/best_20250116_214816.tar +``` + +## c) D+S, pol_error 0, ortho_error 0.1, DGD 0 + +dataset + +```raw + data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini +``` + +model + +```raw + .models/best_20250117_122319.tar +``` + +## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym) + +birefringence angle pi/2 (worst case) + +dataset + +```raw + data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini +``` + +model + +```raw + .models/best_20250117_144001.tar +``` diff --git a/pypho b/pypho index dd015f4..e44fc47 160000 --- a/pypho +++ b/pypho @@ -1 +1 @@ -Subproject commit dd015f48523d73d20391fd006f7116f7bc8c06e2 +Subproject commit e44fc477fed3c81f5255eef71882e24fcefcedc2 diff --git a/src/single-core-regen/hypertraining/models.py b/src/single-core-regen/hypertraining/models.py index 756738c..07e07d3 100644 --- a/src/single-core-regen/hypertraining/models.py +++ b/src/single-core-regen/hypertraining/models.py @@ -164,10 +164,14 @@ class regenerator(Module): module = act_function(size=dims[-1], **act_func_kwargs) self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module) + module = Scale(size=dims[-1]) + self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module) + if self.rotation: module = rotate() self.add_module("rotate", module) + # module = Scale(size=dims[-1]) # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module) diff --git a/src/single-core-regen/hypertraining/settings.py b/src/single-core-regen/hypertraining/settings.py index 1e144ff..2ea54b8 100644 --- a/src/single-core-regen/hypertraining/settings.py +++ b/src/single-core-regen/hypertraining/settings.py @@ -18,9 +18,11 @@ class DataSettings: shuffle: bool = True in_out_delay: float = 0 xy_delay: tuple | float | int = 0 - drop_first: int = 1000 + drop_first: int = 64 + drop_last: int = 64 train_split: float = 0.8 polarisations: tuple | list = (0,) + # cross_pol_interference: float = 0 randomise_polarisations: bool = False osnr: float | int = None seed: int = None @@ -93,6 +95,12 @@ class ModelSettings: """ +def _early_stop_default_kwargs(): + return { + "threshold": 1e-05, + "plateau": 25, + } + @dataclass class OptimizerSettings: optimizer: tuple | str = ("Adam", "RMSprop", "SGD") @@ -101,6 +109,9 @@ class OptimizerSettings: scheduler: str | None = None scheduler_kwargs: dict | None = None + early_stopping: bool = False + early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs) + """ change to: diff --git a/src/single-core-regen/hypertraining/training.py b/src/single-core-regen/hypertraining/training.py index ff54316..f487e75 100644 --- a/src/single-core-regen/hypertraining/training.py +++ b/src/single-core-regen/hypertraining/training.py @@ -4,6 +4,7 @@ from pathlib import Path import random import matplotlib from matplotlib.colors import LinearSegmentedColormap +from mpl_toolkits.axes_grid1 import make_axes_locatable import torch.nn.utils.parametrize try: @@ -46,13 +47,72 @@ from .settings import ( PytorchSettings, ) +from cmcrameri import cm +# from matplotlib import colors as mcolors +# alpha_map = mcolors.LinearSegmentedColormap( +# 'alphamap', +# { +# 'red': [(0, 0, 0), (1, 0, 0)], +# 'green': [(0, 0, 0), (1, 0, 0)], +# 'blue': [(0, 0, 0), (1, 0, 0)], +# 'alpha': [ +# (0, 1, 1), +# # (0.2, 0.2, 0.1), +# (1, 0, 0) +# ] +# } +# ) +# alpha_map.set_bad(color="#AAAAAA") + +def pad_to_size(array, size): + if not hasattr(size, "__len__"): + size = (size, size) + + left = ( + (size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0 + ) + right = ( + (size[0] - array.shape[0]) // 2 if size[0] is not None else 0 + ) + top = ( + (size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0 + ) + bottom = ( + (size[1] - array.shape[1]) // 2 if size[1] is not None else 0 + ) + + array: np.ndarray = array + if array.ndim == 2: + return np.pad( + array, + ( + (left, right), + (top, bottom), + ), + constant_values=(np.nan, np.nan), + ) + elif array.ndim == 3: + return np.pad( + array, + ( + (left, right), + (top, bottom), + (0,0) + ), + constant_values=(np.nan, np.nan), + ) def traverse_dict_update(target, source): for k, v in source.items(): if isinstance(v, dict): - if k not in target: - target[k] = {} - traverse_dict_update(target[k], v) + try: + if k not in target: + target[k] = {} + traverse_dict_update(target[k], v) + except TypeError: + if k not in target.__dict__: + setattr(target, k, {}) + traverse_dict_update(target.__dict__[k], v) else: try: target[k] = v @@ -261,6 +321,7 @@ class PolarizationTrainer: target_delay=in_out_delay, xy_delay=xy_delay, drop_first=self.data_settings.drop_first, + drop_last=self.data_settings.drop_last, dtype=dtype, real=not dtype.is_complex, num_symbols=num_symbols, @@ -602,6 +663,7 @@ class RegenerationTrainer: console=None, checkpoint_path=None, settings_override=None, + new_model=False, reset_epoch=False, ): self.resume = checkpoint_path is not None @@ -615,12 +677,23 @@ class RegenerationTrainer: models.regenerator, torch.nn.utils.parametrizations.orthogonal, ]) + # self.new_model = True + self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S") if self.resume: print(f"loading checkpoint from {checkpoint_path}") self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) if settings_override is not None: traverse_dict_update(self.checkpoint_dict["settings"], settings_override) - if reset_epoch: + + if not new_model: + # self.new_model = False + checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0] + if checkpoint_file.startswith("best"): + self.model_name = "_".join(checkpoint_file.split("_")[1:]) + else: + self.model_name = "_".join(checkpoint_file.split("_")[:-1]) + + if new_model or reset_epoch: self.checkpoint_dict["epoch"] = -1 self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"] @@ -654,7 +727,7 @@ class RegenerationTrainer: self.writer = None def setup_tb_writer(self, append=None): - log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S")) + log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name if append is not None: log_dir += "_" + str(append) @@ -697,8 +770,8 @@ class RegenerationTrainer: output_dim = self.model_settings.output_dim - # if self.data_settings.polarisations: - output_dim *= 2 + if self.data_settings.polarisations: + output_dim *= 2 dtype = getattr(torch, self.data_settings.dtype) @@ -755,11 +828,13 @@ class RegenerationTrainer: randomise_polarisations = self.data_settings.randomise_polarisations polarisations = self.data_settings.polarisations osnr = self.data_settings.osnr + # cross_pol_interference = self.data_settings.cross_pol_interference if override is not None: num_symbols = override.get("num_symbols", None) config_path = override.get("config_path", config_path) polarisations = override.get("polarisations", polarisations) randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations) + # cross_pol_interference = override.get("angle_var", 0) # get dataset dataset = FiberRegenerationDataset( file_path=config_path, @@ -768,11 +843,13 @@ class RegenerationTrainer: target_delay=in_out_delay, xy_delay=xy_delay, drop_first=self.data_settings.drop_first, + drop_last=self.data_settings.drop_last, dtype=dtype, real=not dtype.is_complex, num_symbols=num_symbols, randomise_polarisations=randomise_polarisations, polarisations=polarisations, + # cross_pol_interference=cross_pol_interference, osnr = osnr, ) @@ -842,8 +919,10 @@ class RegenerationTrainer: running_loss = 0.0 self.model.train() loader_len = len(train_loader) - x_key = "x_stacked"# if self.data_settings.polarisations else "x" - y_key = "y_stacked"# if self.data_settings.polarisations else "y" + # x_key = "x_stacked"# if self.data_settings.polarisations else "x" + # y_key = "y_stacked"# if self.data_settings.polarisations else "y" + x_key = "x" + y_key = "y" for batch_idx, batch in enumerate(train_loader): x = batch[x_key] y = batch[y_key] @@ -855,7 +934,10 @@ class RegenerationTrainer: angle.to(self.pytorch_settings.device), ) y_pred = self.model(x, -angle) + # loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True) loss = util.complexNN.complex_mse_loss(y_pred, y, power=True) + + loss_value = loss.item() loss.backward() optimizer.step() @@ -898,8 +980,10 @@ class RegenerationTrainer: self.model.eval() running_error = 0 - x_key = "x_stacked"# if self.data_settings.polarisations else "x" - y_key = "y_stacked"# if self.data_settings.polarisations else "y" + # x_key = "x_stacked"# if self.data_settings.polarisations else "x" + # y_key = "y_stacked"# if self.data_settings.polarisations else "y" + x_key = "x" + y_key = "y" with torch.no_grad(): for _, batch in enumerate(valid_loader): x = batch[x_key] @@ -911,7 +995,9 @@ class RegenerationTrainer: angle.to(self.pytorch_settings.device), ) y_pred = self.model(x, -angle) + # error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True) error = util.complexNN.complex_mse_loss(y_pred, y, power=True) + error_value = error.item() running_error += error_value @@ -928,7 +1014,7 @@ class RegenerationTrainer: if (epoch + 1) % 10 == 0 or epoch < 10: # plotting is slow, so only do it every 10 epochs title_append, subtitle = self.build_title(epoch + 1) - head_fig, eye_fig, powers_fig = self.plot_model_response( + head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response( model=self.model, title_append=title_append, subtitle=subtitle, @@ -944,6 +1030,11 @@ class RegenerationTrainer: eye_fig, epoch + 1, ) + self.writer.add_figure( + "weights", + weight_fig, + epoch + 1, + ) self.writer.add_figure( "powers", @@ -967,9 +1058,10 @@ class RegenerationTrainer: regen = [] timestamps = [] angles = [] - x_key = "x_stacked"# if self.data_settings.polarisations else "x" - y_key = "y_stacked"# if self.data_settings.polarisations else "y" - + # x_key = "x_stacked"# if self.data_settings.polarisations else "x" + # y_key = "y_stacked"# if self.data_settings.polarisations else "y" + x_key = "x" + y_key = "y" with torch.no_grad(): model = model.to(self.pytorch_settings.device) for batch in loader: @@ -1056,7 +1148,7 @@ class RegenerationTrainer: ) title_append, subtitle = self.build_title(0) - head_fig, eye_fig, powers_fig = self.plot_model_response( + head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response( model=self.model, title_append=title_append, subtitle=subtitle, @@ -1072,6 +1164,11 @@ class RegenerationTrainer: eye_fig, 0, ) + self.writer.add_figure( + "weights", + weight_fig, + 0, + ) self.writer.add_figure( "powers", @@ -1103,6 +1200,9 @@ class RegenerationTrainer: train_loader, valid_loader = self.get_sliced_data() + # train_loader.dataset.fiber_out.to(self.pytorch_settings.device) + # train_loader.dataset.fiber_in.to(self.pytorch_settings.device) + optimizer_name = self.optimizer_settings.optimizer # lr = self.optimizer_settings.learning_rate @@ -1132,6 +1232,7 @@ class RegenerationTrainer: # except ValueError: # pass + self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0} for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs): enable_progress = True if enable_progress: @@ -1147,29 +1248,64 @@ class RegenerationTrainer: epoch, enable_progress=enable_progress, ) + if self.early_stop(loss): + self.save_model_checkpoints(epoch, loss) + break if self.optimizer_settings.scheduler is not None: self.scheduler.step(loss) self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch) - if self.pytorch_settings.save_models and self.model is not None: - save_path = ( - Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar" - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - checkpoint = self.build_checkpoint_dict(loss, epoch) - self.save_checkpoint(checkpoint, save_path) - - if loss < self.best["loss"]: - self.best = checkpoint - save_path = ( - Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar" - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - self.save_checkpoint(self.best, save_path) + self.save_model_checkpoints(epoch, loss) self.writer.flush() + save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar") + print(f"Training complete. Best checkpoint: {save_path}") self.writer.close() return self.best + def early_stop(self, loss): + # not stopping early at all + if not self.optimizer_settings.early_stopping: + return False + + # stopping because of abs threshold + if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None: + if loss <= loss_thr: + print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})") + return True + + # update vals + if loss < self.early_stop_vals["min_loss"]: + self.early_stop_vals["min_loss"] = loss + self.early_stop_vals["plateau_cnt"] = 0 + return False + + # stopping because of plateau + if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None: + self.early_stop_vals["plateau_cnt"] += 1 + if self.early_stop_vals["plateau_cnt"] >= plateau_thresh: + print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals["plateau_cnt"]} >= {plateau_thresh})") + return True + + # no stop + return False + + def save_model_checkpoints(self, epoch, loss): + if self.pytorch_settings.save_models and self.model is not None: + save_path = ( + Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar" + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + checkpoint = self.build_checkpoint_dict(loss, epoch) + self.save_checkpoint(checkpoint, save_path) + + if loss < self.best["loss"]: + self.best = checkpoint + save_path = ( + Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar" + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + self.save_checkpoint(self.best, save_path) + def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True): powers = [power / powers[0] for power in powers] fig, ax = plt.subplots() @@ -1190,6 +1326,77 @@ class RegenerationTrainer: plt.show() return fig + def _plot_model_weights(self, model, title_append="", subtitle="", show=True): + model_params = [] + plots = [] + dims = [] + for num, (layer_name, layer) in enumerate(model.named_children()): + onn_weights = layer.ONN.weight + onn_weights = onn_weights.detach().cpu().numpy() + onn_values = np.abs(onn_weights).real + onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real + + model_params.append({layer_name: onn_weights}) + plots.append({layer_name: (num, onn_values, onn_angles)}) + dims.append(onn_weights.shape[0]) + + max_size = np.max(dims) + + for plot in plots: + layer_name, (num, onn_values, onn_angles) = plot.popitem() + + if num == 0: + value_img = onn_values + angle_img = onn_angles + onn_angles = pad_to_size(onn_angles, (max_size, None)) + onn_values = pad_to_size(onn_values, (max_size, None)) + else: + onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1)) + onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1)) + value_img = np.concatenate((value_img, onn_values), axis=1) + angle_img = np.concatenate((angle_img, onn_angles), axis=1) + + value_img = np.ma.array(value_img, mask=np.isnan(value_img)) + angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img)) + + fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5)) + fig.tight_layout() + + dividers = map(make_axes_locatable, axs) + caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers)) + + masked_value_img = value_img + cmap = cm.batlow + cmap.set_bad(color="#AAAAAA") + im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1) + fig.colorbar(im_val, cax=caxs[0], orientation="vertical") + + masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img) + cmap = cm.romaO + cmap.set_bad(color="#AAAAAA") + im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi) + cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)]) + cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"]) + + + axs[0].axis("off") + axs[1].axis("off") + + axs[0].set_title("Values") + axs[1].set_title("Angles") + + title = "Layer Weights" + if title_append: + title += f" {title_append}" + if subtitle: + title += f"\n{subtitle}" + fig.suptitle(title) + + + if show: + plt.show() + return fig + def _plot_model_response_eye( self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True ): @@ -1354,7 +1561,7 @@ class RegenerationTrainer: data_settings_backup = copy.deepcopy(self.data_settings) pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) - self.data_settings.drop_first = 99.5 + random.randint(0, 1000) + self.data_settings.drop_first = int(64 + random.randint(0, 1000)) self.data_settings.shuffle = False self.data_settings.train_split = 1.0 self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols) @@ -1363,7 +1570,7 @@ class RegenerationTrainer: if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path ) - fiber_length = int(float(str(config_path).split("-")[4]) / 1000) + # fiber_length = int(float(str(config_path).split("-")[4]) / 1000) if not hasattr(self, "_plot_loader"): self._plot_loader, _ = self.get_sliced_data( override={ @@ -1376,6 +1583,7 @@ class RegenerationTrainer: } ) self._sps = self._plot_loader.dataset.samples_per_symbol + fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000 self.data_settings = data_settings_backup self.pytorch_settings = pytorch_settings_backup @@ -1403,7 +1611,7 @@ class RegenerationTrainer: import gc head_fig = self._plot_model_response_head( - fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps], + fiber_out[: self.pytorch_settings.head_symbols * self._sps], fiber_in[: self.pytorch_settings.head_symbols * self._sps], regen[: self.pytorch_settings.head_symbols * self._sps], angles[: self.pytorch_settings.head_symbols * self._sps], @@ -1417,7 +1625,7 @@ class RegenerationTrainer: # raise NotImplementedError("Eye diagram not implemented") eye_fig = self._plot_model_response_eye( fiber_in[: self.pytorch_settings.eye_symbols * self._sps], - fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps], + fiber_out[: self.pytorch_settings.eye_symbols * self._sps], regen[: self.pytorch_settings.eye_symbols * self._sps], timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps], labels=("fiber in", "fiber out", "regen"), @@ -1426,9 +1634,11 @@ class RegenerationTrainer: subtitle=subtitle, show=show, ) + + weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show) gc.collect() - return head_fig, eye_fig, power_fig + return head_fig, eye_fig, weight_fig, power_fig def build_title(self, number: int): title_append = f"epoch {number}" diff --git a/src/single-core-regen/plot_model.py b/src/single-core-regen/plot_model.py index 27eab39..06b6cf3 100644 --- a/src/single-core-regen/plot_model.py +++ b/src/single-core-regen/plot_model.py @@ -1,4 +1,6 @@ -import os +from pathlib import Path +import sys + from matplotlib import pyplot as plt import numpy as np import torch @@ -25,7 +27,29 @@ from hypertraining import models # ), # constant_values=(-np.inf, -np.inf), # ) - + +def register_puccs_cmap(puccs_path=None): + puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path + + colors = [] + # keys = None + with open(puccs_path, "r") as f: + for i, line in enumerate(f.readlines()): + elements = tuple(line.split(",")) + # if i == 0: + # # keys = elements + # continue + # else: + try: + colors.append(tuple(map(float, elements[4:]))) + except ValueError: + continue + # colors = [] + # for current in puccs_csv_data: + # colors.append(tuple(current[4:])) + from matplotlib.colors import LinearSegmentedColormap + import matplotlib as mpl + mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors)) def pad_to_size(array, size): if not hasattr(size, "__len__"): @@ -65,7 +89,7 @@ def pad_to_size(array, size): constant_values=(np.nan, np.nan), ) -def model_plot(model_path): +def model_plot(model_path, show=True): torch.serialization.add_safe_globals([ *util.complexNN.__all__, GlobalSettings, @@ -81,173 +105,113 @@ def model_plot(model_path): dims = checkpoint_dict["model_kwargs"].pop("dims") model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"]) - model.load_state_dict(checkpoint_dict["model_state_dict"]) + model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False) model_params = [] plots = [] max_size = np.max(dims) # max_act_size = np.max(dims[1:]) - angles = [None, None] - weights = [None, None] + # angles = [None, None] + # weights = [None, None] for num, (layer_name, layer) in enumerate(model.named_children()): # each layer contains an "ONN" layer and an "activation" layer # activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees - onn_weights = layer.ONN.weight.T + onn_weights = layer.ONN.weight onn_weights = onn_weights.detach().cpu().numpy() onn_values = np.abs(onn_weights).real onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real - - - act = layer.activation - - act_values = np.ones((act.size, 1)) - - act_values = np.nan * act_values - - act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy() - ... - # act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8) - # act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8) - # xs = (0.01, 0.1, 1) - - # act_values = np.zeros((act.size, len(xs)*2)) - # act_angles = np.zeros((act.size, len(xs)*2)) - - # act_values[:,:] = np.nan - # act_angles[:,:] = np.nan - - # for xi, x in enumerate(xs): - # phi_intermediate = act_phi_gain * x**2 + act_phi_bias - - # act_resulting_gain = ( - # 1j - # * torch.sqrt(1-act.alpha) - # * torch.exp(-0.5j * phi_intermediate) - # * torch.cos(0.5 * phi_intermediate) - # * x - # ) - - # act_resulting_gain = act_resulting_gain.detach().cpu().numpy() - # act_values[:, xi*2] = np.abs(act_resulting_gain).real - # act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real - - - - # if angles[0] is None or angles[0] > np.min(onn_angles.flatten()): - # angles[0] = np.min(onn_angles.flatten()) - # if angles[1] is None or angles[1] < np.max(onn_angles.flatten()): - # angles[1] = np.max(onn_angles.flatten()) - # if weights[0] is None or weights[0] > np.min(onn_weights.flatten()): - # weights[0] = np.min(onn_weights.flatten()) - # if weights[1] is None or weights[1] < np.max(onn_weights.flatten()): - # weights[1] = np.max(onn_weights.flatten()) model_params.append({layer_name: onn_weights}) - plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)}) + plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)}) # fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5)) for plot in plots: - layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem() - # for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0]) - # for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0]) - - onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values)) - onn_values = onn_values - np.min(onn_values) - onn_values = onn_values / np.max(onn_values) - - act_values = np.ma.array(act_values, mask=np.isnan(act_values)) - act_values = act_values - np.min(act_values) - act_values = act_values / np.max(act_values) - - - onn_values = onn_values - onn_values = pad_to_size(onn_values, (max_size, None)) - - act_values = act_values - act_values = pad_to_size(act_values, (max_size, 3)) - - onn_angles = onn_angles / np.pi - onn_angles = pad_to_size(onn_angles, (max_size, None)) - - act_angles = act_angles / np.pi - act_angles = pad_to_size(act_angles, (max_size, 3)) - - - - # onn_angles = onn_angles - np.min(onn_angles) - # onn_angles = onn_angles / np.max(onn_angles) - - # act_angles = act_angles - np.min(act_angles) - # act_angles = act_angles / np.max(act_angles) + layer_name, (num, onn_values, onn_angles) = plot.popitem() if num == 0: - value_img = np.concatenate((onn_values, act_values), axis=1) - angle_img = np.concatenate((onn_angles, act_angles), axis=1) + value_img = onn_values + angle_img = onn_angles + onn_angles = pad_to_size(onn_angles, (max_size, None)) + onn_values = pad_to_size(onn_values, (max_size, None)) else: - value_img = np.concatenate((value_img, onn_values, act_values), axis=1) - angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1) - + onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1)) + onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1)) + value_img = np.concatenate((value_img, onn_values), axis=1) + angle_img = np.concatenate((angle_img, onn_angles), axis=1) - + value_img = np.ma.array(value_img, mask=np.isnan(value_img)) + angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img)) - # -np.inf to np.nan - # value_img[value_img == -np.inf] = np.nan - - # angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size) - # angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size) - - + # from cmcrameri import cm + from cmap import Colormap as cm + import scicomap as sc + # from matplotlib import colors as mcolors + # alpha_map = mcolors.LinearSegmentedColormap( + # 'alphamap', + # { + # 'red': [(0, 0, 0), (1, 0, 0)], + # 'green': [(0, 0, 0), (1, 0, 0)], + # 'blue': [(0, 0, 0), (1, 0, 0)], + # 'alpha': [ + # (0, 1, 1), + # # (0.2, 0.2, 0.1), + # (1, 0, 0) + # ] + # } + # ) + # alpha_map.set_bad(color="#AAAAAA") - from cmcrameri import cm - from matplotlib import colors as mcolors - alpha_map = mcolors.LinearSegmentedColormap( - 'alphamap', - { - 'red': [(0, 0, 0), (1, 0, 0)], - 'green': [(0, 0, 0), (1, 0, 0)], - 'blue': [(0, 0, 0), (1, 0, 0)], - 'alpha': [ - (0, 1, 1), - # (0.2, 0.2, 0.1), - (1, 0, 0) - ] - } - ) - alpha_map.set_bad(color="#AAAAAA") - - - fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5)) - fig.tight_layout() + from mpl_toolkits.axes_grid1 import make_axes_locatable + + fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5)) + # fig.tight_layout() + dividers = map(make_axes_locatable, axs) + caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers)) # masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img) masked_value_img = value_img - cmap = cm.batlowW + cmap = cm('google:turbo').to_matplotlib() + # cmap = sc.ScicoSequential("rainbow").get_mpl_color_map() cmap.set_bad(color="#AAAAAA") - im_val = axs[0].imshow(masked_value_img, cmap=cmap) + im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1) + fig.colorbar(im_val, cax=caxs[0], orientation="vertical") + masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img) - cmap = cm.romaO + # cmap = cm('crameri:romao').to_matplotlib() + # cmap = plt.get_cmap('puccs') + # cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map() + cmap = cm('colorcet:CET_C8').to_matplotlib() cmap.set_bad(color="#AAAAAA") - im_ang = axs[1].imshow(masked_angle_img, cmap=cmap) - im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap) - im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map) + im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi) + cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)]) + cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"]) + # im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap) + # im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map) axs[0].axis("off") axs[1].axis("off") - axs[2].axis("off") + # axs[2].axis("off") axs[0].set_title("Values") axs[1].set_title("Angles") - axs[2].set_title("Values and Angles") + # axs[2].set_title("Values and Angles") ... - plt.show() + if show: + plt.show() + return fig + # model = models.regenerator(*dims, **model_kwargs) - if __name__ == "__main__": - model_plot(".models/best_20250105_145719.tar") + register_puccs_cmap() + if len(sys.argv) > 1: + model_plot(sys.argv[1]) + else: + print("Please provide a model path as an argument") + # model_plot(".models/best_20250114_224234.tar") diff --git a/src/single-core-regen/puccs.csv b/src/single-core-regen/puccs.csv new file mode 100644 index 0000000..861dd09 --- /dev/null +++ b/src/single-core-regen/puccs.csv @@ -0,0 +1,102 @@ +"x","L","a","b","R","G","B" +0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0. +0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0. +0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0. +0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0. +0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0. +0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0. +0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0. +0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0. +0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0. +0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0. +0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706 +0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994 +0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696 +0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125 +0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026 +0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075 +0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237 +0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077 +0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024 +0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098 +0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826 +0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393 +0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936 +0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634 +0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265 +0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686 +0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801 +0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104 +0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567 +0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814 +0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928 +0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235 +0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086 +0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596 +0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565 +0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843 +0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497 +0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243 +0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774 +0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205 +0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642 +0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202 +0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765 +0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854 +0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188 +0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339 +0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681 +0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935 +0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526 +0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1. +0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1. +0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1. +0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1. +0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1. +0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1. +0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1. +0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1. +0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1. +0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1. +0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1. +0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1. +0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1. +0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1. +0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1. +0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1. +0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1. +0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1. +0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1. +0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1. +0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1. +0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1. +0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1. +0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1. +0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1. +0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1. +0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1. +0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1. +0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1. +0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1. +0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1. +0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508 +0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966 +0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461 +0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636 +0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118 +0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784 +0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893 +0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306 +0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049 +0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604 +0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527 +0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013 +0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989 +0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278 +0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823 +0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688 +0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577 +0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012 +0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531 +0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665 +1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0. diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py index d68feeb..1450bc7 100644 --- a/src/single-core-regen/regen_no_hyper.py +++ b/src/single-core-regen/regen_no_hyper.py @@ -26,28 +26,39 @@ global_settings = GlobalSettings( ) data_settings = DataSettings( - # config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini", - config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini", - # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], + # config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline + # config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only + # config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1 + # config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a) + # config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b) + # config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c) + # config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd + config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns + + # config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd + + dtype="complex64", # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber - symbols=4, # study: single_core_regen_20241123_011232 + symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y)) - output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2) - shuffle=True, - drop_first=64, + output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232 + shuffle=False, + drop_first=256, + drop_last=256, train_split=0.8, randomise_polarisations=False, - polarisations=True, + polarisations=False, + # cross_pol_interference=0.01, osnr=16, #16dB due to amplification with NF 5 ) pytorch_settings = PytorchSettings( epochs=1000, - batchsize=2**14, + batchsize=2**13, device="cuda", - dataloader_workers=24, - dataloader_prefetch=8, + dataloader_workers=32, + dataloader_prefetch=4, summary_dir=".runs", write_every=2**5, save_models=True, @@ -65,16 +76,13 @@ model_settings = ModelSettings( # "n_hidden_nodes_3": 4, # "n_hidden_nodes_4": 2, }, - model_activation_func="phase_shift", + model_activation_func="EOActivation", dropout_prob=0, model_layer_function="ONNRect", model_layer_kwargs={"square": True}, scale=2.0, model_layer_parametrizations=[ - { - "tensor_name": "weight", - "parametrization": util.complexNN.energy_conserving, - }, + # EOactivation { "tensor_name": "alpha", "parametrization": util.complexNN.clamp, @@ -83,54 +91,20 @@ model_settings = ModelSettings( "max": 1, }, }, + # ONNRect { - "tensor_name": "gain", - "parametrization": util.complexNN.clamp, - "kwargs": { - "min": 0, - "max": None, - }, - }, - { - "tensor_name": "phase_bias", - "parametrization": util.complexNN.clamp, - "kwargs": { - "min": 0, - "max": 2 * torch.pi, - }, + "tensor_name": "weight", + "parametrization": torch.nn.utils.parametrizations.orthogonal, }, + # Scale { "tensor_name": "scale", "parametrization": util.complexNN.clamp, "kwargs": { "min": 0, - "max": 2, + "max": 10, }, - }, - { - "tensor_name": "angle", - "parametrization": util.complexNN.clamp, - "kwargs": { - "min": -torch.pi, - "max": torch.pi, - }, - }, - # { - # "tensor_name": "scale", - # "parametrization": util.complexNN.clamp, - # }, - # { - # "tensor_name": "bias", - # "parametrization": util.complexNN.clamp, - # }, - # { - # "tensor_name": "V", - # "parametrization": torch.nn.utils.parametrizations.orthogonal, - # }, - { - "tensor_name": "loss", - "parametrization": util.complexNN.clamp, - }, + } ], ) @@ -145,191 +119,35 @@ optimizer_settings = OptimizerSettings( scheduler="ReduceLROnPlateau", scheduler_kwargs={ "patience": 2**6, - "factor": 0.75, + "factor": 0.5, # "threshold": 1e-3, "min_lr": 1e-6, "cooldown": 10, }, + early_stopping=True, + early_stop_kwargs={ + "threshold": 1e-06, + "plateau": 2**7, + } ) - -def save_dict_to_file(dictionary, filename): - """ - Save the best dictionary to a JSON file. - - :param best: Dictionary containing the best training results. - :type best: dict - :param filename: Path to the JSON file where the dictionary will be saved. - :type filename: str - """ - with open(filename, "w") as f: - json.dump(dictionary, f, indent=4) - - -def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"): - assert model is not None, "Model must be provided." - assert data_glob is not None, "Data glob must be provided." - model = model - - fiber_ins = {} - fiber_outs = {} - regens = {} - timestampss = {} - - trainer = RegenerationTrainer( - checkpoint_path=model, - ) - trainer.define_model() - - for length in lengths: - data_glob_length = data_glob.replace("{length}", str(length)) - files = list(Path.cwd().glob(data_glob_length)) - if len(files) == 0: - continue - if strategy == "newest": - sorted_kwargs = { - "key": lambda x: x.stat().st_mtime, - "reverse": True, - } - elif strategy == "oldest": - sorted_kwargs = { - "key": lambda x: x.stat().st_mtime, - "reverse": False, - } - else: - raise ValueError(f"Unknown strategy {strategy}.") - file = sorted(files, **sorted_kwargs)[0] - - loader, _ = trainer.get_sliced_data(override={"config_path": file}) - fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader) - - fiber_ins[length] = fiber_in - fiber_outs[length] = fiber_out - regens[length] = regen - timestampss[length] = timestamps - - data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0]) - channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)] - - data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128 - data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square() - - channel_names[1] = "fiber in x" - - for li, length in enumerate(timestampss.keys()): - data[2 + 2 * li, 0, :] = timestampss[length] / 128 - data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square() - data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128 - data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square() - - channel_names[2 + 2 * li + 1] = f"regen x {length}" - channel_names[2 + 2 * li] = f"fiber out x {length}" - - # get current backend - backend = matplotlib.get_backend() - - matplotlib.use("TkCairo") - eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names) - - print_attrs = ("channel_name", "success", "min_area") - with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}): - for result in eye.eye_stats: - print_dict = {attr: result[attr] for attr in print_attrs} - rprint(print_dict) - rprint() - - eye.plot(all_stats=False) - matplotlib.use(backend) - - if __name__ == "__main__": - # lengths = range(90000, 100000+10000, 10000) - # lengths = [100000] - # sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest") - + trainer = RegenerationTrainer( global_settings=global_settings, data_settings=data_settings, pytorch_settings=pytorch_settings, model_settings=model_settings, optimizer_settings=optimizer_settings, - # checkpoint_path=".models/best_20250104_191428.tar", - reset_epoch=True, - # settings_override={ - # "data_settings": { - # "config_path": "data/20241229-163*-128-16384-100000-*.ini", - # "polarisations": True, - # }, - # "model_settings": { - # "scale": 2.0, - # } - # } + checkpoint_path=".models/best_20250117_144001.tar", + new_model=True, + settings_override={ + "data_settings": data_settings.__dict__, # "optimizer_settings": { - # "optimizer_kwargs": { - # "lr": 0.01, - # }, + # "early_stop_kwargs":{ + # "plateau": 2**8, + # } # } - # } - # 20241202_143149 + } ) trainer.train() - - # from hypertraining.lighning_models import regenerator, regeneratorData - # import lightning as L - - # model = regenerator( - # 2 * data_settings.output_size, - # *model_settings.overrides["hidden_layer_dims"], - # model_settings.output_dim, - # layer_function=getattr(util.complexNN, model_settings.model_layer_function), - # layer_func_kwargs=model_settings.model_layer_kwargs, - # act_function=getattr(util.complexNN, model_settings.model_activation_func), - # act_func_kwargs=None, - # parametrizations=model_settings.model_layer_parametrizations, - # dtype=getattr(torch, data_settings.dtype), - # dropout_prob=model_settings.dropout_prob, - # scale_layers=model_settings.scale, - # optimizer=getattr(torch.optim, optimizer_settings.optimizer), - # optimizer_kwargs=optimizer_settings.optimizer_kwargs, - # lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler), - # lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs, - # ) - - # dm = regeneratorData( - # config_globs=data_settings.config_path, - # output_symbols=data_settings.symbols, - # output_dim=data_settings.output_size, - # dtype=getattr(torch, data_settings.dtype), - # drop_first=data_settings.drop_first, - # shuffle=data_settings.shuffle, - # train_split=data_settings.train_split, - # batch_size=pytorch_settings.batchsize, - # loader_settings={ - # "num_workers": pytorch_settings.dataloader_workers, - # "prefetch_factor": pytorch_settings.dataloader_prefetch, - # "pin_memory": True, - # "drop_last": True, - # }, - # seed=global_settings.seed, - # ) - - # # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}") - - # # from torch.utils.tensorboard import SummaryWriter - - # subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}" - - # # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}") - - # logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True) - - # trainer = L.Trainer( - # fast_dev_run=False, - # # max_epochs=pytorch_settings.epochs, - # max_epochs=2, - # enable_checkpointing=True, - # default_root_dir=f".models/{subdir}/", - # logger=logger, - # ) - - # trainer.fit(model, dm) diff --git a/src/single-core-regen/signal_gen/generate_signal.py b/src/single-core-regen/signal_gen/generate_signal.py index d58b28f..4f8e90f 100644 --- a/src/single-core-regen/signal_gen/generate_signal.py +++ b/src/single-core-regen/signal_gen/generate_signal.py @@ -12,6 +12,7 @@ Full license text in LICENSE file """ import configparser +# import copy from datetime import datetime import hashlib from pathlib import Path @@ -40,7 +41,7 @@ alpha = 0.2 D = 17 S = 0.058 bireflength = 10 -max_delta_beta = 0.14 +pmd_q = 0.2 ; birefseed = 0xC0FFEE [signal] @@ -195,10 +196,14 @@ class pam_generator: def initialize_fiber_and_data(config): + f0 = config["glova"].get("f0", None) + if f0 is None: + f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9) + config["glova"]["f0"] = f0 py_glova = pypho.setup( nos=config["glova"]["nos"], sps=config["glova"]["sps"], - f0=config["glova"]["f0"], + f0=f0, symbolrate=config["glova"]["symbolrate"], wisdom_dir=config["glova"]["wisdom_dir"], flags=config["glova"]["flags"], @@ -216,7 +221,9 @@ def initialize_fiber_and_data(config): symbolsrc = pypho.symbols( py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"] ) - laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4) + laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4) + # lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0) + modulator = pam_generator( py_glova, mod_depth=config["signal"]["mod_depth"], @@ -232,7 +239,12 @@ def initialize_fiber_and_data(config): symbols_y[:3] = 0 # symbols_x += 1 - cw = laser() + + cw = laserx() + # cwy = lasery() + # cw[0]['E'][0] = cw[0]['E'][0] + # cw[0]['E'][1] = cwy[0]['E'][0] + source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y)) @@ -251,13 +263,41 @@ def initialize_fiber_and_data(config): # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))] - nf = py_edfa.NF - source_signal = py_edfa(E=source_signal, NF=0) - py_edfa.NF = nf + ## side channels + # df = 100 + # signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))) + + # symbols_x_side = symbolsrc(pattern="random") + # symbols_y_side = symbolsrc(pattern="random") + # symbols_x_side[:3] = 0 + # symbols_y_side[:3] = 0 + + # cw_left = laser(Df=-df) + # source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side)) + + # cw_right = laser(Df=df) + # source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side)) + + E_in_pure = source_signal[0]["E"] + + nf = py_edfa.NF + pmean = py_edfa.Pmean + + # ideal amplification to launch power into fiber + source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"]) + # source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"]) + # source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"]) + + # source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0] + # source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1] + c_data.E_in = source_signal[0]["E"] noise = source_signal[0]["noise"] + py_edfa.NF = nf + py_edfa.Pmean = pmean + py_fiber = pypho.fiber( glova=py_glova, l=config["fiber"]["length"], @@ -265,20 +305,29 @@ def initialize_fiber_and_data(config): gamma=config["fiber"]["gamma"], D=config["fiber"]["d"], S=config["fiber"]["s"], + phi_max=0.02, ) - if config["fiber"].get("birefsteps", 0) > 0: + + config["fiber"]["birefsteps"] = config["fiber"].get( + "birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"]) + ) + if config["fiber"]["birefsteps"] > 0: + config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"]) seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32) py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre( - py_fiber.l, - py_fiber.l / config["fiber"]["birefsteps"], - # maxDeltaD=config["fiber"]["d"]/5, - maxDeltaBeta=config["fiber"].get("max_delta_beta", 0), + config["fiber"]["length"], + config["fiber"]["bireflength"], + maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]), seed=seed, ) - c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200) + elif (dgd := config['fiber'].get('dgd', 0)) > 0: + py_fiber.birefarray = [ + pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"]) + ] + c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200) c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova) - return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y) + return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure def save_data(data, config, **metadata): @@ -316,8 +365,11 @@ def save_data(data, config, **metadata): f"D = {config['fiber']['d']}", f"S = {config['fiber']['s']}", f"birefsteps = {config['fiber'].get('birefsteps', 0)}", - f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}", + f"pmd_q = {config['fiber'].get('pmd_q', 0)}", f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set", + f"dgd = {config['fiber'].get('dgd', 0)}", + f"ortho_error = {config['fiber'].get('ortho_error', 0)}", + f"pol_error = {config['fiber'].get('pol_error', 0)}", "", "[signal]", f"seed = {hex(seed)}" if seed else "; seed = not set", @@ -346,24 +398,12 @@ def save_data(data, config, **metadata): save_file = f"{config_hash}.h5" config_content += f'"{str(save_file)}"\n' - filename_components = ( - timestamp.strftime("%Y%m%d-%H%M%S"), - config["glova"]["sps"], - config["glova"]["nos"], - config["signal"]["osnr"], - config["fiber"]["length"], - config["fiber"]["gamma"], - config["fiber"]["alpha"], - config["fiber"]["d"], - config["fiber"]["s"], - f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}", - config["fiber"].get("birefsteps", 0), - config["fiber"].get("max_delta_beta", 0), - int(config["glova"]["symbolrate"] / 1e9), - ) + config_filename:Path = create_config_filename(config, data_dir, timestamp) + while config_filename.exists(): + time.sleep(1) + config_filename = create_config_filename(config, data_dir=data_dir) + - lookup_file = "-".join(map(str, filename_components)) + ".ini" - config_filename = data_dir / lookup_file with open(config_filename, "w") as f: f.write(config_content) @@ -376,11 +416,31 @@ def save_data(data, config, **metadata): outfile.attrs[key] = value # np.save(save_dir / save_file, save_data) - print("Saved config to", config_filename) - print("Saved data to", save_dir / save_file) + # print("Saved config to", config_filename) + # print("Saved data to", save_dir / save_file) return config_filename +def create_config_filename(config, data_dir:Path, timestamp=None): + if timestamp is None: + timestamp = datetime.now() + filename_components = ( + timestamp.strftime("%Y%m%d-%H%M%S"), + config["glova"]["sps"], + config["glova"]["nos"], + config["signal"]["osnr"], + config["fiber"]["length"], + config["fiber"]["gamma"], + config["fiber"]["alpha"], + config["fiber"]["d"], + config["fiber"]["s"], + f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}", + config["fiber"].get("birefsteps", 0), + config["fiber"].get("pmd_q", 0), + int(config["glova"]["symbolrate"] / 1e9), + ) + lookup_file = "-".join(map(str, filename_components)) + ".ini" + return data_dir / lookup_file def length_loop(config, lengths, save=True): lengths = sorted(lengths) @@ -388,7 +448,7 @@ def length_loop(config, lengths, save=True): print(f"\nGenerating data for fiber length {length}m") config["fiber"]["length"] = length - cfiber, cdata, noise, edfa = initialize_fiber_and_data(config) + cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config) mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) cfiber() @@ -416,51 +476,49 @@ def single_run_with_plot(config, save=True): in_out_eyes(cfiber, cdata, show_pols=False) return config_filename -def single_run(config, save=True): - cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config) - # mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) - # print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)") - - # estimate osnr - # noise_power = np.mean(noise) - # osnr_lin = mean_power_in / noise_power - 1 - # osnr = 10 * np.log10(osnr_lin) - # print(f"Estimated OSNR: {osnr:.3f} dB") +def single_run(config, save=True, silent=True): + cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config) + # transmit cfiber() - # mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out)) - # print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)") - - # noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha) - - # estimate osnr - # noise_power = np.mean(noise) - # osnr_lin = mean_power_out / noise_power - 1 - # osnr = 10 * np.log10(osnr_lin) - # print(f"Estimated OSNR: {osnr:.3f} dB") - + # amplify E_tmp = [{"E": cdata.E_out, "noise": noise}] + E_tmp = edfa(E=E_tmp) + + + # rotate + # ortho error + ortho_error = config["fiber"].get("ortho_error", 0) + + E_tmp[0]["E"] = np.stack(( + E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2), + E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2) + ), axis=0) + + + pol_error = config['fiber'].get('pol_error', 0) + + E_tmp[0]["E"] = np.stack(( + E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error), + E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error) + ), axis=0) + + + + + # output cdata.E_out = E_tmp[0]["E"] - # noise = E_tmp[0]["noise"] - - # mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out)) - - # print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)") - - # estimate osnr - # noise_power = np.mean(noise) - # osnr_lin = mean_power_amp / noise_power - 1 - # osnr = 10 * np.log10(osnr_lin) - # print(f"Estimated OSNR: {osnr:.3f} dB") config_filename = None symbols = np.array(symbols) if save: config_filename = save_data(cdata, config, **{"symbols": symbols}) - return cfiber,cdata,config_filename + if not silent: + print(f"Saved config to {config_filename}") + return cfiber, cdata, config_filename def in_out_eyes(cfiber, cdata, show_pols=False): diff --git a/src/single-core-regen/testing/prob_dens.ipynb b/src/single-core-regen/testing/prob_dens.ipynb new file mode 100644 index 0000000..883407a --- /dev/null +++ b/src/single-core-regen/testing/prob_dens.ipynb @@ -0,0 +1,138 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Probability density of pmd for different base distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "\n", + "trials = 10000\n", + "segments = 10000\n", + "# length = 100000\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Phase and amplitude, uniform distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "phase = np.random.rand(trials) * 2 * np.pi\n", + "amp = np.random.rand(trials)\n", + "\n", + "res = amp * np.cos(phase) - amp*np.sin(phase)\n", + "res = np.abs(res)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAGdCAYAAAD+JxxnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArIklEQVR4nO3df3RU9Z3/8Vd+MBMQJgiYhJQgKCpGMRyCxKm1HmxkVlO3rHiMPw7NItRFA0eI5deqRGy34aBb0RJh1a7xjyKQnuquBEPZYHArUTSQU1BgtUBhFyfg0WQiQhKSz/cPv7llIJBMyGSS+Twf59xznHvf9877k5TMq5/7Y2KMMUYAAAAWio10AwAAAJFCEAIAANYiCAEAAGsRhAAAgLUIQgAAwFoEIQAAYC2CEAAAsBZBCAAAWCs+0g30Zq2trTp69KgGDRqkmJiYSLcDAAA6wRijhoYGpaamKjb2wnM+BKELOHr0qNLS0iLdBgAA6IIjR45oxIgRF6whCF3AoEGDJH33g/R4PBHuBgAAdEYgEFBaWprzOX4hBKELaDsd5vF4CEIAAPQxnbmshYulAQCAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACsRRACAADWIggBAABrEYQAAIC1CEIAAMBaBCEAAGAtghAAALAWQQgAAFiLIAQAAKxFEAIAANYiCAEAAGsRhAAAgLUIQgAAwFoEIQAAYC2CEAAAsBZBCAAAWIsgBAAArEUQAgAA1iIIAQAAaxGEAACAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACsRRACAADWIggBAABrEYQAAIC1CEIAAMBaBCEAAGAtghAAALAWQQgAAFiLIAQAAKxFEAIAANYiCAEAAGsRhAAAgLVCCkJPP/20YmJigpaxY8c620+dOqX8/HwNHTpUAwcO1LRp01RbWxt0jMOHDysnJ0cDBgxQUlKSFixYoNOnTwfVVFZWasKECXK73RozZoxKSkrO6aW4uFijRo1SQkKCsrKytGPHjqDtnekFAADYLeQZoeuuu05ffPGFs/zpT39yts2fP19vv/22SktLtW3bNh09elR33323s72lpUU5OTlqamrS9u3b9frrr6ukpERLly51ag4ePKicnBxNnjxZNTU1mjdvnmbNmqXNmzc7NevXr1dBQYEKCwu1c+dOZWRkyOfz6dixY53uBQAAQCYEhYWFJiMjo91tdXV1pl+/fqa0tNRZt3fvXiPJVFVVGWOM2bRpk4mNjTV+v9+pWb16tfF4PKaxsdEYY8zChQvNddddF3Ts3Nxc4/P5nNeTJk0y+fn5zuuWlhaTmppqioqKOt1LZ9TX1xtJpr6+vtP7AACAyArl8zvkGaHPPvtMqampuuKKK/Tggw/q8OHDkqTq6mo1NzcrOzvbqR07dqxGjhypqqoqSVJVVZXGjRun5ORkp8bn8ykQCOiTTz5xas48RltN2zGamppUXV0dVBMbG6vs7GynpjO9tKexsVGBQCBoAQAA0SukIJSVlaWSkhKVl5dr9erVOnjwoG655RY1NDTI7/fL5XJp8ODBQfskJyfL7/dLkvx+f1AIatvetu1CNYFAQCdPntSXX36plpaWdmvOPEZHvbSnqKhIiYmJzpKWlta5HwwAAOiT4kMpvuOOO5z/vuGGG5SVlaXLL79cGzZsUP/+/bu9uZ62ZMkSFRQUOK8DgQBhCACAKHZRt88PHjxYV199tT7//HOlpKSoqalJdXV1QTW1tbVKSUmRJKWkpJxz51bb645qPB6P+vfvr2HDhikuLq7dmjOP0VEv7XG73fJ4PEELAACIXhcVhL755hv95S9/0fDhw5WZmal+/fqpoqLC2b5//34dPnxYXq9XkuT1erV79+6gu7u2bNkij8ej9PR0p+bMY7TVtB3D5XIpMzMzqKa1tVUVFRVOTWd6AQAACOmusccff9xUVlaagwcPmvfff99kZ2ebYcOGmWPHjhljjJk9e7YZOXKk2bp1q/n444+N1+s1Xq/X2f/06dPm+uuvN1OmTDE1NTWmvLzcXHbZZWbJkiVOzYEDB8yAAQPMggULzN69e01xcbGJi4sz5eXlTs26deuM2+02JSUl5tNPPzUPP/ywGTx4cNDdaB310hncNQYAQN8Tyud3SEEoNzfXDB8+3LhcLvO9733P5Obmms8//9zZfvLkSfPoo4+aSy+91AwYMMD8wz/8g/niiy+CjnHo0CFzxx13mP79+5thw4aZxx9/3DQ3NwfVvPvuu2b8+PHG5XKZK664wrz22mvn9PKb3/zGjBw50rhcLjNp0iTzwQcfBG3vTC8dIQgBAND3hPL5HWOMMZGdk+q9AoGAEhMTVV9fz/VCAAD0EaF8fvNdYwAAwFoEIQAAYC2CEAAAsBZBCAAAWIsgBAAArEUQAgAA1iIIAQAAaxGEAACAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACsRRACAADWIggBAABrEYQAAIC1CEIAAMBaBCEAAGAtghAAALAWQSiCRi0ui3QLAABYjSAEAACsRRACAADWIggBAABrEYQAAIC1CEIAAMBaBCEAAGAtghAAALAWQQgAAFiLIAQAAKxFEAIAANYiCAEAAGsRhAAAgLUIQgAAwFoEIQAAYC2CEAAAsBZBCAAAWIsgBAAArEUQAgAA1iIIAQAAaxGEAACAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACsRRACAADWIggBAABrEYQAAIC1CEIAAMBaBCEAAGAtghAAALAWQQgAAFiLIAQAAKx1UUFo+fLliomJ0bx585x1p06dUn5+voYOHaqBAwdq2rRpqq2tDdrv8OHDysnJ0YABA5SUlKQFCxbo9OnTQTWVlZWaMGGC3G63xowZo5KSknPev7i4WKNGjVJCQoKysrK0Y8eOoO2d6QUAANiry0Hoo48+0r/927/phhtuCFo/f/58vf322yotLdW2bdt09OhR3X333c72lpYW5eTkqKmpSdu3b9frr7+ukpISLV261Kk5ePCgcnJyNHnyZNXU1GjevHmaNWuWNm/e7NSsX79eBQUFKiws1M6dO5WRkSGfz6djx451uhcAAGA50wUNDQ3mqquuMlu2bDG33nqreeyxx4wxxtTV1Zl+/fqZ0tJSp3bv3r1GkqmqqjLGGLNp0yYTGxtr/H6/U7N69Wrj8XhMY2OjMcaYhQsXmuuuuy7oPXNzc43P53NeT5o0yeTn5zuvW1paTGpqqikqKup0Lx2pr683kkx9fX2n6kN1+aKN5vJFG8NybAAAbBXK53eXZoTy8/OVk5Oj7OzsoPXV1dVqbm4OWj927FiNHDlSVVVVkqSqqiqNGzdOycnJTo3P51MgENAnn3zi1Jx9bJ/P5xyjqalJ1dXVQTWxsbHKzs52ajrTy9kaGxsVCASCFgAAEL3iQ91h3bp12rlzpz766KNztvn9frlcLg0ePDhofXJysvx+v1NzZghq29627UI1gUBAJ0+e1Ndff62WlpZ2a/bt29fpXs5WVFSkZcuWXWD0AAAgmoQ0I3TkyBE99thj+t3vfqeEhIRw9RQxS5YsUX19vbMcOXIk0i0BAIAwCikIVVdX69ixY5owYYLi4+MVHx+vbdu26cUXX1R8fLySk5PV1NSkurq6oP1qa2uVkpIiSUpJSTnnzq221x3VeDwe9e/fX8OGDVNcXFy7NWceo6NezuZ2u+XxeIIWAAAQvUIKQj/60Y+0e/du1dTUOMvEiRP14IMPOv/dr18/VVRUOPvs379fhw8fltfrlSR5vV7t3r076O6uLVu2yOPxKD093ak58xhtNW3HcLlcyszMDKppbW1VRUWFU5OZmdlhLwAAwG4hXSM0aNAgXX/99UHrLrnkEg0dOtRZP3PmTBUUFGjIkCHyeDyaO3euvF6vbrrpJknSlClTlJ6erunTp2vFihXy+/168sknlZ+fL7fbLUmaPXu2Vq1apYULF+qhhx7S1q1btWHDBpWVlTnvW1BQoLy8PE2cOFGTJk3SypUrdeLECc2YMUOSlJiY2GEvAADAbiFfLN2R559/XrGxsZo2bZoaGxvl8/n00ksvOdvj4uK0ceNGPfLII/J6vbrkkkuUl5enZ555xqkZPXq0ysrKNH/+fL3wwgsaMWKEXn31Vfl8PqcmNzdXx48f19KlS+X3+zV+/HiVl5cHXUDdUS8AAMBuMcYYE+kmeqtAIKDExETV19eH5XqhUYu/m+E6tDyn248NAICtQvn85rvGAACAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACsRRACAADWIggBAABrEYQAAIC1CEK9QNsTpgEAQM8iCAEAAGsRhAAAgLUIQgAAwFoEIQAAYC2CEAAAsBZBCAAAWIsgBAAArEUQAgAA1iIIAQAAaxGEAACAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACsRRACAADWIggBAABrEYR6iVGLyyLdAgAA1iEIAQAAaxGEAACAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACsRRACAADWIgj1IjxUEQCAnkUQAgAA1iIIAQAAaxGEAACAtQhCAADAWgQhAABgLYIQAACwFkGol+EWegAAeg5BCAAAWIsgBAAArEUQAgAA1iIIAQAAaxGEAACAtQhCvRB3jgEA0DMIQgAAwFoEIQAAYC2CEAAAsFZIQWj16tW64YYb5PF45PF45PV69c477zjbT506pfz8fA0dOlQDBw7UtGnTVFtbG3SMw4cPKycnRwMGDFBSUpIWLFig06dPB9VUVlZqwoQJcrvdGjNmjEpKSs7ppbi4WKNGjVJCQoKysrK0Y8eOoO2d6QUAANgtpCA0YsQILV++XNXV1fr4449122236Sc/+Yk++eQTSdL8+fP19ttvq7S0VNu2bdPRo0d19913O/u3tLQoJydHTU1N2r59u15//XWVlJRo6dKlTs3BgweVk5OjyZMnq6amRvPmzdOsWbO0efNmp2b9+vUqKChQYWGhdu7cqYyMDPl8Ph07dsyp6aiX3o4LpgEACL8YY4y5mAMMGTJEzz77rO655x5ddtllWrt2re655x5J0r59+3TttdeqqqpKN910k9555x39+Mc/1tGjR5WcnCxJWrNmjRYtWqTjx4/L5XJp0aJFKisr0549e5z3uO+++1RXV6fy8nJJUlZWlm688UatWrVKktTa2qq0tDTNnTtXixcvVn19fYe9dEYgEFBiYqLq6+vl8Xgu5sfUro7CzqHlOd3+ngAARLtQPr+7fI1QS0uL1q1bpxMnTsjr9aq6ulrNzc3Kzs52asaOHauRI0eqqqpKklRVVaVx48Y5IUiSfD6fAoGAM6tUVVUVdIy2mrZjNDU1qbq6OqgmNjZW2dnZTk1negEAAIgPdYfdu3fL6/Xq1KlTGjhwoN58802lp6erpqZGLpdLgwcPDqpPTk6W3++XJPn9/qAQ1La9bduFagKBgE6ePKmvv/5aLS0t7dbs27fPOUZHvbSnsbFRjY2NzutAINDBTwMAAPRlIc8IXXPNNaqpqdGHH36oRx55RHl5efr000/D0VuPKyoqUmJiorOkpaVFuiUAABBGIQchl8ulMWPGKDMzU0VFRcrIyNALL7yglJQUNTU1qa6uLqi+trZWKSkpkqSUlJRz7txqe91RjcfjUf/+/TVs2DDFxcW1W3PmMTrqpT1LlixRfX29sxw5cqRzPxQAANAnXfRzhFpbW9XY2KjMzEz169dPFRUVzrb9+/fr8OHD8nq9kiSv16vdu3cH3d21ZcsWeTwepaenOzVnHqOtpu0YLpdLmZmZQTWtra2qqKhwajrTS3vcbrfzaIC2BQAARK+QrhFasmSJ7rjjDo0cOVINDQ1au3atKisrtXnzZiUmJmrmzJkqKCjQkCFD5PF4NHfuXHm9XucurSlTpig9PV3Tp0/XihUr5Pf79eSTTyo/P19ut1uSNHv2bK1atUoLFy7UQw89pK1bt2rDhg0qK/vbHVYFBQXKy8vTxIkTNWnSJK1cuVInTpzQjBkzJKlTvfQFoxaXcecYAABhFFIQOnbsmH7605/qiy++UGJiom644QZt3rxZt99+uyTp+eefV2xsrKZNm6bGxkb5fD699NJLzv5xcXHauHGjHnnkEXm9Xl1yySXKy8vTM88849SMHj1aZWVlmj9/vl544QWNGDFCr776qnw+n1OTm5ur48ePa+nSpfL7/Ro/frzKy8uDLqDuqBcAAICLfo5QNIv0c4QkniUEAECoeuQ5QgAAAH0dQaiX46s2AAAIH4IQAACwFkEIAABYiyAEAACsRRDqA7hOCACA8CAIAQAAaxGEAACAtQhCfQSnxwAA6H4EIQAAYC2CEAAAsBZBqA/h9BgAAN2LIAQAAKxFEAIAANYiCAEAAGsRhPoYrhMCAKD7EIQAAIC1CEIAAMBaBKE+iNNjAAB0D4IQAACwFkGoj2JWCACAi0cQAgAA1iIIAQAAaxGE+jBOjwEAcHEIQgAAwFoEoT6OWSEAALqOIAQAAKxFEAIAANYiCEUBTo8BANA1BCEAAGAtglCUYFYIAIDQEYQAAIC1CEJRhFkhAABCQxACAADWIghFGWaFAADoPIIQAACwFkEoCjErBABA5xCEohRhCACAjhGEAACAtQhCUYxZIQAALowgBAAArEUQinLMCgEAcH4EIQAAYC2CkAWYFQIAoH0EIUsQhgAAOBdBCAAAWIsgZBFmhQAACEYQsgxhCACAvyEIWYgwBADAdwhCAADAWgQhS41aXMbMEADAegQhAABgLYKQ5ZgVAgDYjCAEwhAAwFohBaGioiLdeOONGjRokJKSkjR16lTt378/qObUqVPKz8/X0KFDNXDgQE2bNk21tbVBNYcPH1ZOTo4GDBigpKQkLViwQKdPnw6qqays1IQJE+R2uzVmzBiVlJSc009xcbFGjRqlhIQEZWVlaceOHSH3gu8QhgAANgopCG3btk35+fn64IMPtGXLFjU3N2vKlCk6ceKEUzN//ny9/fbbKi0t1bZt23T06FHdfffdzvaWlhbl5OSoqalJ27dv1+uvv66SkhItXbrUqTl48KBycnI0efJk1dTUaN68eZo1a5Y2b97s1Kxfv14FBQUqLCzUzp07lZGRIZ/Pp2PHjnW6FwQjDAEAbBNjjDFd3fn48eNKSkrStm3b9MMf/lD19fW67LLLtHbtWt1zzz2SpH379unaa69VVVWVbrrpJr3zzjv68Y9/rKNHjyo5OVmStGbNGi1atEjHjx+Xy+XSokWLVFZWpj179jjvdd9996murk7l5eWSpKysLN14441atWqVJKm1tVVpaWmaO3euFi9e3KleOhIIBJSYmKj6+np5PJ6u/pjOqzcGj0PLcyLdAgAAFyWUz++Lukaovr5ekjRkyBBJUnV1tZqbm5Wdne3UjB07ViNHjlRVVZUkqaqqSuPGjXNCkCT5fD4FAgF98sknTs2Zx2iraTtGU1OTqqurg2piY2OVnZ3t1HSml7M1NjYqEAgELbbhtnoAgE26HIRaW1s1b9483Xzzzbr++uslSX6/Xy6XS4MHDw6qTU5Olt/vd2rODEFt29u2XagmEAjo5MmT+vLLL9XS0tJuzZnH6KiXsxUVFSkxMdFZ0tLSOvnTAAAAfVGXg1B+fr727NmjdevWdWc/EbVkyRLV19c7y5EjRyLdUsQwKwQAsEGXgtCcOXO0ceNGvfvuuxoxYoSzPiUlRU1NTaqrqwuqr62tVUpKilNz9p1bba87qvF4POrfv7+GDRumuLi4dmvOPEZHvZzN7XbL4/EELTYjDAEAol1IQcgYozlz5ujNN9/U1q1bNXr06KDtmZmZ6tevnyoqKpx1+/fv1+HDh+X1eiVJXq9Xu3fvDrq7a8uWLfJ4PEpPT3dqzjxGW03bMVwulzIzM4NqWltbVVFR4dR0phd0jGuGAADRLD6U4vz8fK1du1b/8R//oUGDBjnX2iQmJqp///5KTEzUzJkzVVBQoCFDhsjj8Wju3Lnyer3OXVpTpkxRenq6pk+frhUrVsjv9+vJJ59Ufn6+3G63JGn27NlatWqVFi5cqIceekhbt27Vhg0bVFb2tw/kgoIC5eXlaeLEiZo0aZJWrlypEydOaMaMGU5PHfWCzhu1uIw7ygAAUSek2+djYmLaXf/aa6/pH//xHyV99xDDxx9/XG+88YYaGxvl8/n00ksvBZ2O+utf/6pHHnlElZWVuuSSS5SXl6fly5crPv5vuayyslLz58/Xp59+qhEjRuipp55y3qPNqlWr9Oyzz8rv92v8+PF68cUXlZWV5WzvTC8XYuPt8x0hDAEAertQPr8v6jlC0Y4gdC6CEACgt+ux5wjBPn0xvAEAcD4EIYSMMAQAiBYEIXQJd5MBAKIBQQgXhUAEAOjLCEIAAMBaBCF0C2aGAAB9EUEI3YowBADoSwhC6HaEIQBAX0EQQlgQhgAAfQFBCGHDdUMAgN6OIISwIwwBAHqrkL59HuiqM8MQ31cGAOgtmBFCj2OGCADQWxCEEBGEIQBAb0AQQsQQhgAAkUYQQkRxZxkAIJIIQugVCEQAgEggCKFXIRABAHoSQQi9EmEIANATCELotZgdAgCEG0EIvR6BCAAQLgQh9BkEIgBAdyMIoc8hEAEAugtBCH0WYQgAcLH40lX0aXyZKwDgYjAjhKjBKTMAQKgIQog6BCIAQGdxagxRi9NmAICOMCMEKzBLBABoD0EIViEQAQDORBCClQhEAACJa4RgOa4jAgC7MSME/H/MEAGAfZgRAs5wdhhilggAohszQsAFMEsEANGNGSGgA1xHBADRiyAEhIBQBADRhSAEdBGhCAD6Pq4RAroBzyUCgL6JIAR0IwIRAPQtnBoDwoDb8AGgbyAIAT2A64kAoHfi1BjQwzh9BgC9BzNCQIQwSwQAkUcQAnqB9maICEcAEH6cGgN6KU6fAUD4MSME9GLcfQYA4UUQAvoQrisCgO5FEAL6KEIRAFw8ghAQBTiFBgBdQxACohCzRQDQOQQhIMpxaz4AnB9BCLAQM0YA8J2QnyP03nvv6a677lJqaqpiYmL01ltvBW03xmjp0qUaPny4+vfvr+zsbH322WdBNV999ZUefPBBeTweDR48WDNnztQ333wTVPPnP/9Zt9xyixISEpSWlqYVK1ac00tpaanGjh2rhIQEjRs3Tps2bQq5F8B2bV/5wVd/ALBRyEHoxIkTysjIUHFxcbvbV6xYoRdffFFr1qzRhx9+qEsuuUQ+n0+nTp1yah588EF98skn2rJlizZu3Kj33ntPDz/8sLM9EAhoypQpuvzyy1VdXa1nn31WTz/9tF5++WWnZvv27br//vs1c+ZM7dq1S1OnTtXUqVO1Z8+ekHoBEIxQBMAmMcYY0+WdY2L05ptvaurUqZK+m4FJTU3V448/rp///OeSpPr6eiUnJ6ukpET33Xef9u7dq/T0dH300UeaOHGiJKm8vFx33nmn/vd//1epqalavXq1nnjiCfn9frlcLknS4sWL9dZbb2nfvn2SpNzcXJ04cUIbN250+rnppps0fvx4rVmzplO9dCQQCCgxMVH19fXyeDxd/TGdFx806Gs4jQagLwjl87tbv2Lj4MGD8vv9ys7OdtYlJiYqKytLVVVVkqSqqioNHjzYCUGSlJ2drdjYWH344YdOzQ9/+EMnBEmSz+fT/v379fXXXzs1Z75PW03b+3Sml7M1NjYqEAgELQD+htNoAKJNt14s7ff7JUnJyclB65OTk51tfr9fSUlJwU3Ex2vIkCFBNaNHjz7nGG3bLr30Uvn9/g7fp6NezlZUVKRly5Z1brAAeH4RgD6PL109w5IlS1RfX+8sR44ciXRLQJ/S3mwRM0cAerNunRFKSUmRJNXW1mr48OHO+traWo0fP96pOXbsWNB+p0+f1ldffeXsn5KSotra2qCattcd1Zy5vaNezuZ2u+V2uzs9XgDn114YYsYIQG/TrTNCo0ePVkpKiioqKpx1gUBAH374obxeryTJ6/Wqrq5O1dXVTs3WrVvV2tqqrKwsp+a9995Tc3OzU7NlyxZdc801uvTSS52aM9+nrabtfTrTC4CexTVGAHqbkGeEvvnmG33++efO64MHD6qmpkZDhgzRyJEjNW/ePP3yl7/UVVddpdGjR+upp55Samqqc2fZtddeq7/7u7/Tz372M61Zs0bNzc2aM2eO7rvvPqWmpkqSHnjgAS1btkwzZ87UokWLtGfPHr3wwgt6/vnnnfd97LHHdOutt+pf//VflZOTo3Xr1unjjz92brGPiYnpsBcAkXW+p16PWlzG7BGAHhHy7fOVlZWaPHnyOevz8vJUUlIiY4wKCwv18ssvq66uTj/4wQ/00ksv6eqrr3Zqv/rqK82ZM0dvv/22YmNjNW3aNL344osaOHCgU/PnP/9Z+fn5+uijjzRs2DDNnTtXixYtCnrP0tJSPfnkkzp06JCuuuoqrVixQnfeeaezvTO9XAi3zwO9A6EIQChC+fy+qOcIRTuCENA7EYwAXEgon9981xiAPocvkgXQXQhCAKICzzQC0BUEIQBRiWAEoDMIQgCscL5r8ghIgN0IQgCsxswRYDeCEACcgZkjwC4EIQDoBGaOgOhEEAKALuAWfiA6EIQAoJsQjoC+hyAEAGHENUdA70YQAoAI4JojoHcgCAFAL8DMERAZBCEA6MW47ggIL4IQAPQxzB4B3YcgBABRgoAEhC420g0AAMJr1OIyZ2l7DeA7zAgBgEUuFIaYOYKNCEIAAEmEI9iJIAQAOK/zhaNRi8sISYgKBCEAQEg4vYZoQhACAHQbwhH6GoIQACCsCEfozQhCAIAed6Fb+AlJ6EkEIQBAr8IX0qInEYQAAL1ae8GIu9bQXQhCAIA+paMnZBOQEAqCEAAgqnBxNkJBEAIARD3CEc6HIAQAsBLhCBJBCAAAB+HIPgQhAAAugHAU3QhCAACEiDvWogdBCACAbsLsUd9DEAIAIIwIR70bQQgAgB5GOOo9CEIAAPQChKPIIAgBANBL8QW04UcQAgCgj2DWqPsRhAAA6MOYNbo4BCEAAKIIs0ahIQgBABDlzgxHhKJgBCEAACzCqbRgBCEAACxmezAiCAEAAIdtp9EIQgAAoF02hCKCEAAA6FC0nkKLjXQDAACg7xm1uKzdW/X7GmaEAABAl/X102cEIQAA0C36Yiji1BgAAOh2feXUGUEIAACETW8PQwQhAAAQVr15doggBAAAekRvDERWBKHi4mKNGjVKCQkJysrK0o4dOyLdEgAA1upNYSjqg9D69etVUFCgwsJC7dy5UxkZGfL5fDp27FikWwMAwFq9JQxFfRD69a9/rZ/97GeaMWOG0tPTtWbNGg0YMED//u//HunWAABAhEX1c4SamppUXV2tJUuWOOtiY2OVnZ2tqqqqc+obGxvV2NjovK6vr5ckBQKBsPTX2vhtWI4LAEBfMHJ+qfYs83X7cds+t40xHdZGdRD68ssv1dLSouTk5KD1ycnJ2rdv3zn1RUVFWrZs2Tnr09LSwtYjAAA2S1wZvmM3NDQoMTHxgjVRHYRCtWTJEhUUFDivW1tb9dVXX2no0KGKiYnp1vcKBAJKS0vTkSNH5PF4uvXYfYHN47d57BLjZ/yM39bx9+TYjTFqaGhQampqh7VRHYSGDRumuLg41dbWBq2vra1VSkrKOfVut1tutzto3eDBg8PZojwej3X/GM5k8/htHrvE+Bk/47d1/D019o5mgtpE9cXSLpdLmZmZqqiocNa1traqoqJCXq83gp0BAIDeIKpnhCSpoKBAeXl5mjhxoiZNmqSVK1fqxIkTmjFjRqRbAwAAERb1QSg3N1fHjx/X0qVL5ff7NX78eJWXl59zAXVPc7vdKiwsPOdUnC1sHr/NY5cYP+Nn/LaOv7eOPcZ05t4yAACAKBTV1wgBAABcCEEIAABYiyAEAACsRRACAADWIgiFUXFxsUaNGqWEhARlZWVpx44dF6wvLS3V2LFjlZCQoHHjxmnTpk091Gl4hDL+V155RbfccosuvfRSXXrppcrOzu7w59Wbhfq7b7Nu3TrFxMRo6tSp4W0wzEIdf11dnfLz8zV8+HC53W5dffXVffp//6GOf+XKlbrmmmvUv39/paWlaf78+Tp16lQPddt93nvvPd11111KTU1VTEyM3nrrrQ73qays1IQJE+R2uzVmzBiVlJSEvc9wCXX8f/jDH3T77bfrsssuk8fjkdfr1ebNm3um2TDoyu+/zfvvv6/4+HiNHz8+bP2dD0EoTNavX6+CggIVFhZq586dysjIkM/n07Fjx9qt3759u+6//37NnDlTu3bt0tSpUzV16lTt2bOnhzvvHqGOv7KyUvfff7/effddVVVVKS0tTVOmTNH//d//9XDnFy/Usbc5dOiQfv7zn+uWW27poU7DI9TxNzU16fbbb9ehQ4f0+9//Xvv379crr7yi733vez3cefcIdfxr167V4sWLVVhYqL179+q3v/2t1q9fr3/+53/u4c4v3okTJ5SRkaHi4uJO1R88eFA5OTmaPHmyampqNG/ePM2aNavPhoFQx//ee+/p9ttv16ZNm1RdXa3Jkyfrrrvu0q5du8LcaXiEOv42dXV1+ulPf6of/ehHYeqsAwZhMWnSJJOfn++8bmlpMampqaaoqKjd+nvvvdfk5OQErcvKyjL/9E//FNY+wyXU8Z/t9OnTZtCgQeb1118PV4th05Wxnz592nz/+983r776qsnLyzM/+clPeqDT8Ah1/KtXrzZXXHGFaWpq6qkWwyrU8efn55vbbrstaF1BQYG5+eabw9pnuEkyb7755gVrFi5caK677rqgdbm5ucbn84Wxs57RmfG3Jz093Sxbtqz7G+phoYw/NzfXPPnkk6awsNBkZGSEta/2MCMUBk1NTaqurlZ2drazLjY2VtnZ2aqqqmp3n6qqqqB6SfL5fOet7826Mv6zffvtt2pubtaQIUPC1WZYdHXszzzzjJKSkjRz5syeaDNsujL+//zP/5TX61V+fr6Sk5N1/fXX61e/+pVaWlp6qu1u05Xxf//731d1dbVz+uzAgQPatGmT7rzzzh7pOZKi6e9ed2htbVVDQ0Of+7t3MV577TUdOHBAhYWFEesh6p8sHQlffvmlWlpaznl6dXJysvbt29fuPn6/v916v98ftj7DpSvjP9uiRYuUmpp6zh/J3q4rY//Tn/6k3/72t6qpqemBDsOrK+M/cOCAtm7dqgcffFCbNm3S559/rkcffVTNzc0R/ePYFV0Z/wMPPKAvv/xSP/jBD2SM0enTpzV79uw+eWosVOf7uxcIBHTy5En1798/Qp1FxnPPPadvvvlG9957b6Rb6RGfffaZFi9erP/+7/9WfHzk4ggzQuh1li9frnXr1unNN99UQkJCpNsJq4aGBk2fPl2vvPKKhg0bFul2IqK1tVVJSUl6+eWXlZmZqdzcXD3xxBNas2ZNpFvrEZWVlfrVr36ll156STt37tQf/vAHlZWV6Re/+EWkW0MPWrt2rZYtW6YNGzYoKSkp0u2EXUtLix544AEtW7ZMV199dUR7YUYoDIYNG6a4uDjV1tYGra+trVVKSkq7+6SkpIRU35t1ZfxtnnvuOS1fvlz/9V//pRtuuCGcbYZFqGP/y1/+okOHDumuu+5y1rW2tkqS4uPjtX//fl155ZXhbbobdeV3P3z4cPXr109xcXHOumuvvVZ+v19NTU1yuVxh7bk7dWX8Tz31lKZPn65Zs2ZJksaNG6cTJ07o4Ycf1hNPPKHY2Oj9/6vn+7vn8Xismg1at26dZs2apdLS0j43C95VDQ0N+vjjj7Vr1y7NmTNH0nd/+4wxio+P1x//+EfddtttPdJL9P4LiyCXy6XMzExVVFQ461pbW1VRUSGv19vuPl6vN6hekrZs2XLe+t6sK+OXpBUrVugXv/iFysvLNXHixJ5otduFOvaxY8dq9+7dqqmpcZa///u/d+6iSUtL68n2L1pXfvc333yzPv/8cycAStL//M//aPjw4X0qBEldG/+33357TthpC4Umyr8KMpr+7nXVG2+8oRkzZuiNN95QTk5OpNvpMR6P55y/fbNnz9Y111yjmpoaZWVl9VwzPX55tiXWrVtn3G63KSkpMZ9++ql5+OGHzeDBg43f7zfGGDN9+nSzePFip/7999838fHx5rnnnjN79+41hYWFpl+/fmb37t2RGsJFCXX8y5cvNy6Xy/z+9783X3zxhbM0NDREaghdFurYz9bX7xoLdfyHDx82gwYNMnPmzDH79+83GzduNElJSeaXv/xlpIZwUUIdf2FhoRk0aJB54403zIEDB8wf//hHc+WVV5p77703UkPosoaGBrNr1y6za9cuI8n8+te/Nrt27TJ//etfjTHGLF682EyfPt2pP3DggBkwYIBZsGCB2bt3rykuLjZxcXGmvLw8UkO4KKGO/3e/+52Jj483xcXFQX/36urqIjWEixLq+M8WqbvGCEJh9Jvf/MaMHDnSuFwuM2nSJPPBBx8422699VaTl5cXVL9hwwZz9dVXG5fLZa677jpTVlbWwx13r1DGf/nllxtJ5yyFhYU933g3CPV3f6a+HoSMCX3827dvN1lZWcbtdpsrrrjC/Mu//Is5ffp0D3fdfUIZf3Nzs3n66afNlVdeaRISEkxaWpp59NFHzddff93zjV+kd999t91/x23jzcvLM7feeus5+4wfP964XC5zxRVXmNdee63H++4uoY7/1ltvvWB9X9OV3/+ZIhWEYoyJ8rlXAACA8+AaIQAAYC2CEAAAsBZBCAAAWIsgBAAArEUQAgAA1iIIAQAAaxGEAACAtQhCAADAWgQhAABgLYIQAACwFkEIAABYiyAEAACs9f8A8CGPNODc0vgAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_ = plt.hist(res, bins=min(trials//100, 1000))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Uniform phase, uniform amplitude" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ + "delta_beta = np.random.randn(trials, segments)+0.5\n", + "phi = np.random.rand(trials, segments)*2*np.pi\n", + "\n", + "beta_x = delta_beta * np.cos(phi)\n", + "beta_y = delta_beta * np.sin(phi)\n", + "\n", + "delta_beta = np.abs(beta_x - beta_y)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlEUlEQVR4nO3dfXBU1cHH8V8SkuUtu2mQZMmQAKIFImAtQthKlUpKgNTqGKdiqWDLwEg3TiGtlVjq6zOG2heoVMBaC20HSnVGUEFBDBBqjaCpjLyUqIgFCxtUSpaXYSHJef7wYR83hJdN9mbPhu9n5s6w9569e84hbH6cc+69ScYYIwAAAIskx7sCAAAAzRFQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADW6RTvCrRGU1OTDhw4oPT0dCUlJcW7OgAA4CIYY3T06FHl5OQoOfn8YyQJGVAOHDig3NzceFcDAAC0wv79+9W7d+/zlknIgJKeni7p8wa63e441wYAAFyMYDCo3Nzc8O/x80nIgHJmWsftdhNQAABIMBezPINFsgAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADW6RTvCgBAc31nrzlr30dzi+NQEwDxwggKAACwDiMoAGKGkQ8AscIICgAAsA4BBQAAWIeAAgAArENAAQAA1mGRLIB2xUJaABeDgALAUS0FEgC4EKZ4AACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdnsUD4KLE+yF/zT+fBwwCHRsBBejg4h0sAKA1CCgAWo0nFQNwCmtQAACAdQgoAADAOgQUAABgHdagAIg71rIAaI6AAlyCuGQXgO2Y4gEAANYhoAAAAOswxQOAm7kBsE5UIyiLFi3S0KFD5Xa75Xa75fP59Morr4SPnzx5Un6/Xz169FD37t1VUlKiurq6iHPs27dPxcXF6tq1q7KysnTvvfeqoaEhNq0BAAAdQlQBpXfv3po7d65qamr09ttv68Ybb9TNN9+snTt3SpJmzZqll156Sc8995yqqqp04MAB3XrrreH3NzY2qri4WKdOndIbb7yhP/3pT1q6dKkeeOCB2LYKAAAktCRjjGnLCTIzM/XLX/5St912m3r27Knly5frtttukyTt3r1bgwYNUnV1tUaOHKlXXnlF3/rWt3TgwAFlZ2dLkhYvXqz77rtPn3zyidLS0i7qM4PBoDwej+rr6+V2u9tSfaDDa+0lvM2neGy7FJgpKCDxRPP7u9WLZBsbG7VixQodP35cPp9PNTU1On36tAoLC8NlBg4cqLy8PFVXV0uSqqurNWTIkHA4kaSioiIFg8HwKExLQqGQgsFgxAYAADquqAPK9u3b1b17d7lcLt19991auXKl8vPzFQgElJaWpoyMjIjy2dnZCgQCkqRAIBARTs4cP3PsXCoqKuTxeMJbbm5utNUGAAAJJOqAMmDAAG3btk1btmzRjBkzNGXKFO3atcuJuoWVl5ervr4+vO3fv9/RzwMAAPEV9WXGaWlpuuKKKyRJw4YN01tvvaXf/va3uv3223Xq1CkdOXIkYhSlrq5OXq9XkuT1erV169aI8525yudMmZa4XC65XK5oqwoAABJUm2/U1tTUpFAopGHDhik1NVWVlZXhY7W1tdq3b598Pp8kyefzafv27Tp06FC4zPr16+V2u5Wfn9/WqgAAgA4iqhGU8vJyjR8/Xnl5eTp69KiWL1+uTZs2ad26dfJ4PJo6darKysqUmZkpt9ute+65Rz6fTyNHjpQkjR07Vvn5+brzzjv1+OOPKxAIaM6cOfL7/YyQAACAsKgCyqFDhzR58mQdPHhQHo9HQ4cO1bp16/TNb35TkjRv3jwlJyerpKREoVBIRUVFWrhwYfj9KSkpWr16tWbMmCGfz6du3bppypQpeuSRR2LbKgAAkNDafB+UeOA+KMDF4z4oAGwRze9vnsUDICHx/CCgY+NpxgAAwDqMoABokW1TOgAuLYygAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANbpFO8KAIidvrPXxLsKABATjKAAAADrEFAAAIB1mOIBEhhTOgA6KkZQAACAdQgoAADAOgQUAABgHdagAOgwmq/J+WhucZxqAqCtGEEBAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWCeqgFJRUaHhw4crPT1dWVlZuuWWW1RbWxtRZvTo0UpKSorY7r777ogy+/btU3Fxsbp27aqsrCzde++9amhoaHtrAABAh9ApmsJVVVXy+/0aPny4GhoadP/992vs2LHatWuXunXrFi43bdo0PfLII+HXXbt2Df+5sbFRxcXF8nq9euONN3Tw4EFNnjxZqampeuyxx2LQJAAAkOiiCihr166NeL106VJlZWWppqZG119/fXh/165d5fV6WzzHq6++ql27dum1115Tdna2vvKVr+jRRx/Vfffdp4ceekhpaWmtaAYAAOhI2rQGpb6+XpKUmZkZsX/ZsmW67LLLNHjwYJWXl+vEiRPhY9XV1RoyZIiys7PD+4qKihQMBrVz584WPycUCikYDEZsQEfXd/aaiA0ALiVRjaB8UVNTk2bOnKnrrrtOgwcPDu//7ne/qz59+ignJ0fvvvuu7rvvPtXW1ur555+XJAUCgYhwIin8OhAItPhZFRUVevjhh1tbVQCXqJaC3Udzi+NQEwDRanVA8fv92rFjh15//fWI/dOnTw//eciQIerVq5fGjBmjPXv2qH///q36rPLycpWVlYVfB4NB5ebmtq7iAADAeq2a4iktLdXq1au1ceNG9e7d+7xlCwoKJEkffPCBJMnr9aquri6izJnX51q34nK55Ha7IzYAANBxRRVQjDEqLS3VypUrtWHDBvXr1++C79m2bZskqVevXpIkn8+n7du369ChQ+Ey69evl9vtVn5+fjTVAQAAHVRUUzx+v1/Lly/XCy+8oPT09PCaEY/Hoy5dumjPnj1avny5JkyYoB49eujdd9/VrFmzdP3112vo0KGSpLFjxyo/P1933nmnHn/8cQUCAc2ZM0d+v18ulyv2LQQAAAknqhGURYsWqb6+XqNHj1avXr3C29/+9jdJUlpaml577TWNHTtWAwcO1I9//GOVlJTopZdeCp8jJSVFq1evVkpKinw+n773ve9p8uTJEfdNAQAAl7aoRlCMMec9npubq6qqqguep0+fPnr55Zej+WgAAHAJ4Vk8AADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1Wv00YwCx03f2mnhXAQCswggKAACwDgEFAABYhykeIEEwDQTgUsIICgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsw7N4AFxSmj/T6KO5xXGqCYDzYQQFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKzDre4BXNKa3/pe4vb3gA0YQQEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWCeqgFJRUaHhw4crPT1dWVlZuuWWW1RbWxtR5uTJk/L7/erRo4e6d++ukpIS1dXVRZTZt2+fiouL1bVrV2VlZenee+9VQ0ND21sDAAA6hKhu1FZVVSW/36/hw4eroaFB999/v8aOHatdu3apW7dukqRZs2ZpzZo1eu655+TxeFRaWqpbb71V//jHPyRJjY2NKi4ultfr1RtvvKGDBw9q8uTJSk1N1WOPPRb7FgIWaunmYACA/5dkjDGtffMnn3yirKwsVVVV6frrr1d9fb169uyp5cuX67bbbpMk7d69W4MGDVJ1dbVGjhypV155Rd/61rd04MABZWdnS5IWL16s++67T5988onS0tIu+LnBYFAej0f19fVyu92trT4QNwQUu3EnWcAZ0fz+btOt7uvr6yVJmZmZkqSamhqdPn1ahYWF4TIDBw5UXl5eOKBUV1dryJAh4XAiSUVFRZoxY4Z27typa6655qzPCYVCCoVCEQ0EEgVhBACi1+pFsk1NTZo5c6auu+46DR48WJIUCASUlpamjIyMiLLZ2dkKBALhMl8MJ2eOnznWkoqKCnk8nvCWm5vb2moDAIAE0OqA4vf7tWPHDq1YsSKW9WlReXm56uvrw9v+/fsd/0wAABA/rZriKS0t1erVq7V582b17t07vN/r9erUqVM6cuRIxChKXV2dvF5vuMzWrVsjznfmKp8zZZpzuVxyuVytqSoAAEhAUY2gGGNUWlqqlStXasOGDerXr1/E8WHDhik1NVWVlZXhfbW1tdq3b598Pp8kyefzafv27Tp06FC4zPr16+V2u5Wfn9+WtgAAgA4iqhEUv9+v5cuX64UXXlB6enp4zYjH41GXLl3k8Xg0depUlZWVKTMzU263W/fcc498Pp9GjhwpSRo7dqzy8/N155136vHHH1cgENCcOXPk9/sZJQEAAJKiDCiLFi2SJI0ePTpi/5IlS3TXXXdJkubNm6fk5GSVlJQoFAqpqKhICxcuDJdNSUnR6tWrNWPGDPl8PnXr1k1TpkzRI4880raWAACADqNN90GJF+6DgkTCZcaJh/ugAM5ot/ugADgbgQQA2o6HBQIAAOswggIAzTQfBWPKB2h/jKAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDo8iwdoA55cfGlo6e+Z5/MAzmIEBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANbpFO8KAEBH1Xf2mojXH80tjlNNgMTDCAoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDpRB5TNmzfrpptuUk5OjpKSkrRq1aqI43fddZeSkpIitnHjxkWUOXz4sCZNmiS3262MjAxNnTpVx44da1NDAABAxxF1QDl+/LiuvvpqPfnkk+csM27cOB08eDC8/fWvf404PmnSJO3cuVPr16/X6tWrtXnzZk2fPj362gMAgA4p6jvJjh8/XuPHjz9vGZfLJa/X2+Kxf/3rX1q7dq3eeustXXvttZKkBQsWaMKECfrVr36lnJycaKsEAAA6GEfWoGzatElZWVkaMGCAZsyYoc8++yx8rLq6WhkZGeFwIkmFhYVKTk7Wli1bWjxfKBRSMBiM2AAAQMcV84Aybtw4/fnPf1ZlZaV+8YtfqKqqSuPHj1djY6MkKRAIKCsrK+I9nTp1UmZmpgKBQIvnrKiokMfjCW+5ubmxrjYAALBIzB8WOHHixPCfhwwZoqFDh6p///7atGmTxowZ06pzlpeXq6ysLPw6GAwSUgAA6MAcv8z48ssv12WXXaYPPvhAkuT1enXo0KGIMg0NDTp8+PA51624XC653e6IDQAAdFyOB5SPP/5Yn332mXr16iVJ8vl8OnLkiGpqasJlNmzYoKamJhUUFDhdHQAAkACinuI5duxYeDREkvbu3att27YpMzNTmZmZevjhh1VSUiKv16s9e/bopz/9qa644goVFRVJkgYNGqRx48Zp2rRpWrx4sU6fPq3S0lJNnDiRK3hgvb6z18S7CgBwSYh6BOXtt9/WNddco2uuuUaSVFZWpmuuuUYPPPCAUlJS9O677+rb3/62vvzlL2vq1KkaNmyY/v73v8vlcoXPsWzZMg0cOFBjxozRhAkTNGrUKP3+97+PXasAAEBCi3oEZfTo0TLGnPP4unXrLniOzMxMLV++PNqPBgAAlwiexQMAAKwT88uMgY6C9SYAED+MoAAAAOswggIArcAIG+AsAgrwf/iFAwD2YIoHAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDqd4l0BALhU9J295qx9H80tjkNNAPsxggIAAKxDQAEAANZhigeXpJaG2gEA9mAEBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOlxmDABx1PySd+4sC3yOERQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWCfqgLJ582bddNNNysnJUVJSklatWhVx3BijBx54QL169VKXLl1UWFio999/P6LM4cOHNWnSJLndbmVkZGjq1Kk6duxYmxoCnNF39pqzNgBAYok6oBw/flxXX321nnzyyRaPP/7443riiSe0ePFibdmyRd26dVNRUZFOnjwZLjNp0iTt3LlT69ev1+rVq7V582ZNnz699a0AAAAdStRPMx4/frzGjx/f4jFjjObPn685c+bo5ptvliT9+c9/VnZ2tlatWqWJEyfqX//6l9auXau33npL1157rSRpwYIFmjBhgn71q18pJyenDc0BAAAdQUzXoOzdu1eBQECFhYXhfR6PRwUFBaqurpYkVVdXKyMjIxxOJKmwsFDJycnasmVLLKsDAAASVNQjKOcTCAQkSdnZ2RH7s7Ozw8cCgYCysrIiK9GpkzIzM8NlmguFQgqFQuHXwWAwltUGAACWSYireCoqKuTxeMJbbm5uvKsEAAAcFNOA4vV6JUl1dXUR++vq6sLHvF6vDh06FHG8oaFBhw8fDpdprry8XPX19eFt//79saw2AACwTEwDSr9+/eT1elVZWRneFwwGtWXLFvl8PkmSz+fTkSNHVFNTEy6zYcMGNTU1qaCgoMXzulwuud3uiA0AAHRcUa9BOXbsmD744IPw671792rbtm3KzMxUXl6eZs6cqf/5n//RlVdeqX79+unnP/+5cnJydMstt0iSBg0apHHjxmnatGlavHixTp8+rdLSUk2cOJEreAAAgKRWBJS3335b3/jGN8Kvy8rKJElTpkzR0qVL9dOf/lTHjx/X9OnTdeTIEY0aNUpr165V586dw+9ZtmyZSktLNWbMGCUnJ6ukpERPPPFEDJoDAAA6giRjjIl3JaIVDAbl8XhUX1/PdA/O0tKdYz+aW3zBMoANmv+sAh1JNL+/E+IqHgAAcGmJ6X1QAFsxYgIAiYURFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA63CZMQBYrvll8tzMDZcCRlAAAIB1CCgAAMA6BBQAAGAd1qAAgEV4LAPwOUZQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh6t4AKAD4G6z6GgYQQEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdXgWDxJK8+eNSDxzBAA6IkZQAACAdQgoAADAOgQUAABgHdagIOG1tC4FAJDYGEEBAADWYQQFABLMxYwacsUbEh0BBdbgCxUAcAZTPAAAwDqMoMBqLIAFgEsTIygAAMA6BBQAAGCdmAeUhx56SElJSRHbwIEDw8dPnjwpv9+vHj16qHv37iopKVFdXV2sqwEAABKYIyMoV111lQ4ePBjeXn/99fCxWbNm6aWXXtJzzz2nqqoqHThwQLfeeqsT1QAAAAnKkUWynTp1ktfrPWt/fX29nnnmGS1fvlw33nijJGnJkiUaNGiQ3nzzTY0cOdKJ6gAAgATjyAjK+++/r5ycHF1++eWaNGmS9u3bJ0mqqanR6dOnVVhYGC47cOBA5eXlqbq6+pznC4VCCgaDERsAAOi4Yh5QCgoKtHTpUq1du1aLFi3S3r179fWvf11Hjx5VIBBQWlqaMjIyIt6TnZ2tQCBwznNWVFTI4/GEt9zc3FhXGwAAWCTmUzzjx48P/3no0KEqKChQnz599Oyzz6pLly6tOmd5ebnKysrCr4PBICEFAIAOzPHLjDMyMvTlL39ZH3zwgbxer06dOqUjR45ElKmrq2txzcoZLpdLbrc7YgMAAB2X4wHl2LFj2rNnj3r16qVhw4YpNTVVlZWV4eO1tbXat2+ffD6f01UBAAAJIuZTPD/5yU900003qU+fPjpw4IAefPBBpaSk6I477pDH49HUqVNVVlamzMxMud1u3XPPPfL5fFzBAwAAwmIeUD7++GPdcccd+uyzz9SzZ0+NGjVKb775pnr27ClJmjdvnpKTk1VSUqJQKKSioiItXLgw1tUAAAAJLMkYY+JdiWgFg0F5PB7V19ezHqUD4cGAgLM+mlsc7yrgEhfN72+exQMAAKxDQAEAANZx5Fb3AAD7Xcy0KtNCiBdGUAAAgHUYQUG7YAEsACAaBBQAuETwHwUkEqZ4AACAdQgoAADAOkzxAADOqaVpIa7sQXtgBAUAAFiHgAIAAKzDFA8cwdUCAIC2YAQFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKzDre4BAFFp/igLnm4MJzCCAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOlzFgwiszgcQrebfG9LZ3x0XUwb4IgIKzqulL5Xm+JIBAMQaAQUAEHMX858b4HwIKGgzvogAALHGIlkAAGAdAgoAALAOAQUAAFiHNSiXENaKALAJtzXA+TCCAgAArMMICgDACtzMDV9EQOmgmM4B0BEwDXTpYooHAABYhxGUBMD/IADgc0wDXToYQQEAANaJ6wjKk08+qV/+8pcKBAK6+uqrtWDBAo0YMSKeVWp3sRodYc0JgEuVU9+jjMzEV9wCyt/+9jeVlZVp8eLFKigo0Pz581VUVKTa2lplZWXFq1qOilWIIIwAQHT43kw8ScYYE48PLigo0PDhw/W73/1OktTU1KTc3Fzdc889mj179nnfGwwG5fF4VF9fL7fb3R7VPUtrkjb/QAAgsTGq0jbR/P6OywjKqVOnVFNTo/Ly8vC+5ORkFRYWqrq6+qzyoVBIoVAo/Lq+vl7S5w11wuAH10W83vFw0VllmkInIl5fTF2avwcAkFic+r0Tb81/70kt/+5rqzP9dzFjI3EJKJ9++qkaGxuVnZ0dsT87O1u7d+8+q3xFRYUefvjhs/bn5uY6Vscv8syPTRkAQGK7lL7rnWzr0aNH5fF4zlsmIS4zLi8vV1lZWfh1U1OTDh8+rB49eigpKSmONTtbMBhUbm6u9u/fH7fpJ1vRN+dH/5wf/XNu9M350T/n1t59Y4zR0aNHlZOTc8GycQkol112mVJSUlRXVxexv66uTl6v96zyLpdLLpcrYl9GRoaTVWwzt9vNP4RzoG/Oj/45P/rn3Oib86N/zq09++ZCIydnxOU+KGlpaRo2bJgqKyvD+5qamlRZWSmfzxePKgEAAIvEbYqnrKxMU6ZM0bXXXqsRI0Zo/vz5On78uL7//e/Hq0oAAMAScQsot99+uz755BM98MADCgQC+spXvqK1a9eetXA20bhcLj344INnTUmBvrkQ+uf86J9zo2/Oj/45N5v7Jm73QQEAADgXnsUDAACsQ0ABAADWIaAAAADrEFAAAIB1CCgX8OSTT6pv377q3LmzCgoKtHXr1nOWHT16tJKSks7aiov//+FSx44dU2lpqXr37q0uXbooPz9fixcvbo+mOCLW/VNXV6e77rpLOTk56tq1q8aNG6f333+/PZoSc9H0jSTNnz9fAwYMUJcuXZSbm6tZs2bp5MmTbTqnzWLdP5s3b9ZNN92knJwcJSUladWqVQ63wFmx7p+KigoNHz5c6enpysrK0i233KLa2lqnm+GIWPfNokWLNHTo0PDNynw+n1555RWnm+EYJ757zpg7d66SkpI0c+ZMB2rejME5rVixwqSlpZk//vGPZufOnWbatGkmIyPD1NXVtVj+s88+MwcPHgxvO3bsMCkpKWbJkiXhMtOmTTP9+/c3GzduNHv37jVPPfWUSUlJMS+88EI7tSp2Yt0/TU1NZuTIkebrX/+62bp1q9m9e7eZPn26ycvLM8eOHWvHlrVdtH2zbNky43K5zLJly8zevXvNunXrTK9evcysWbNafU6bOdE/L7/8svnZz35mnn/+eSPJrFy5sp1aE3tO9E9RUZFZsmSJ2bFjh9m2bZuZMGEC/7b+z4svvmjWrFlj3nvvPVNbW2vuv/9+k5qaanbs2NFezYoZJ/rnjK1bt5q+ffuaoUOHmh/96EcOt8QYAsp5jBgxwvj9/vDrxsZGk5OTYyoqKi7q/fPmzTPp6ekRXwBXXXWVeeSRRyLKffWrXzU/+9nPYlPpdhTr/qmtrTWSIr4UGhsbTc+ePc3TTz8d28o7LNq+8fv95sYbb4zYV1ZWZq677rpWn9NmTvTPFyV6QHG6f4wx5tChQ0aSqaqqik2l20l79I0xxnzpS18yf/jDH9pe4XbmVP8cPXrUXHnllWb9+vXmhhtuaJeAwhTPOZw6dUo1NTUqLCwM70tOTlZhYaGqq6sv6hzPPPOMJk6cqG7duoX3fe1rX9OLL76o//znPzLGaOPGjXrvvfc0duzYmLfBSU70TygUkiR17tw54pwul0uvv/56DGvvrNb0zde+9jXV1NSEh2I//PBDvfzyy5owYUKrz2krJ/qnI2mv/qmvr5ckZWZmxrD2zmqPvmlsbNSKFSt0/PjxhHv0ipP94/f7VVxcHHFupyXE04zj4dNPP1VjY+NZd7bNzs7W7t27L/j+rVu3aseOHXrmmWci9i9YsEDTp09X79691alTJyUnJ+vpp5/W9ddfH9P6O82J/hk4cKDy8vJUXl6up556St26ddO8efP08ccf6+DBgzFvg1Na0zff/e539emnn2rUqFEyxqihoUF333237r///laf01ZO9E9H0h7909TUpJkzZ+q6667T4MGDY94GpzjZN9u3b5fP59PJkyfVvXt3rVy5Uvn5+Y61xQlO9c+KFSv0z3/+U2+99Zaj9W+OERSHPPPMMxoyZIhGjBgRsX/BggV688039eKLL6qmpka//vWv5ff79dprr8WppvHRUv+kpqbq+eef13vvvafMzEx17dpVGzdu1Pjx45Wc3LF/VDdt2qTHHntMCxcu1D//+U89//zzWrNmjR599NF4V80K9M/5Rds/fr9fO3bs0IoVK9q5pu3vYvtmwIAB2rZtm7Zs2aIZM2ZoypQp2rVrV5xq3X4u1D/79+/Xj370Iy1btixidLtdOD6JlKBCoZBJSUk5ax578uTJ5tvf/vZ533vs2DHjdrvN/PnzI/afOHHCpKammtWrV0fsnzp1qikqKopJvduLE/3zRUeOHDGHDh0yxnw+p/rDH/6wzXVuL63pm1GjRpmf/OQnEfv+8pe/mC5dupjGxsY29bdtnOif5pTAa1Cc7h+/32969+5tPvzww5jWuz20x8/OGWPGjDHTp09vc53bkxP9s3LlSiPJpKSkhDdJJikpyaSkpJiGhganmsMalHNJS0vTsGHDVFlZGd7X1NSkysrKC85LPvfccwqFQvre974Xsf/06dM6ffr0WaMBKSkpampqil3l24ET/fNFHo9HPXv21Pvvv6+3335bN998c8zq7rTW9M2JEyda/LmQJGNMm/rbNk70T0fiVP8YY1RaWqqVK1dqw4YN6tevn0MtcE57/uw0NTWF18UlCif6Z8yYMdq+fbu2bdsW3q699lpNmjRJ27ZtC5d1hGPRpwNYsWKFcblcZunSpWbXrl1m+vTpJiMjwwQCAWOMMXfeeaeZPXv2We8bNWqUuf3221s85w033GCuuuoqs3HjRvPhhx+aJUuWmM6dO5uFCxc62hYnONE/zz77rNm4caPZs2ePWbVqlenTp4+59dZbHW2HE6LtmwcffNCkp6ebv/71r+bDDz80r776qunfv7/5zne+c9HnTCRO9M/Ro0fNO++8Y9555x0jyfzmN78x77zzjvn3v//d7u1rKyf6Z8aMGcbj8ZhNmzZFXO5/4sSJdm9fWzjRN7NnzzZVVVVm79695t133zWzZ882SUlJ5tVXX2339rWVE/3TXHtdxUNAuYAFCxaYvLw8k5aWZkaMGGHefPPN8LEbbrjBTJkyJaL87t27jaRz/mAfPHjQ3HXXXSYnJ8d07tzZDBgwwPz61782TU1NTjbDMbHun9/+9remd+/eJjU11eTl5Zk5c+aYUCjkZBMcE03fnD592jz00EOmf//+pnPnziY3N9f88Ic/NP/9738v+pyJJtb9s3HjRiPprK35z2CiiHX/tNQ3kiLu05QoYt03P/jBD0yfPn1MWlqa6dmzpxkzZkxChpMznPju+aL2CihJxnSw8VEAAJDwWIMCAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHX+F4kbcfz4bHZNAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "delta_t = np.mean(delta_beta, axis=0)\n", + "_ = plt.hist(delta_t, bins=100)\n", + "# _ = plt.hist(delta_t)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/single-core-regen/tolerance_testing.py b/src/single-core-regen/tolerance_testing.py index 92b6329..c0cc058 100644 --- a/src/single-core-regen/tolerance_testing.py +++ b/src/single-core-regen/tolerance_testing.py @@ -7,7 +7,10 @@ tests a given model for tolerance against variations in CD, PMD, baudrate need different datasets, osnr is modeled as awgn added to the data before feeding into the model """ +import copy from datetime import datetime +import hashlib +import sys from typing import Literal from matplotlib import pyplot as plt import numpy as np @@ -23,6 +26,70 @@ from signal_gen.generate_signal import single_run, get_config import json +from util.mpl import eyediagram + + +from functools import lru_cache, wraps + +class YetAnotherWrapper: + def __init__(self, x: np.array) -> None: + self.values = x + # here you can use your own hashing function + self.h = hashlib.sha224(np.ascontiguousarray(x).data.tobytes()).hexdigest() + + + def __hash__(self) -> int: + return hash(self.h) + + def __eq__(self, __value: object) -> bool: + return __value.h == self.h + +def memoizer(expensive_function): + @lru_cache() + def cached_wrapper(shell): + return expensive_function(shell.values) + + @wraps(expensive_function) + def wrapper(x: np.array): + shell = YetAnotherWrapper(x) + return cached_wrapper(shell) + + return wrapper + +# def np_cache(function): +# @lru_cache() +# def cached_wrapper(hashable_array): +# array = np.array(hashable_array) +# return function(array) + +# @wraps(function) +# def wrapper(array): + +# # copy lru_cache attributes over too +# wrapper.cache_info = cached_wrapper.cache_info +# wrapper.cache_clear = cached_wrapper.cache_clear + +# return wrapper + + +def ticks_format(value, index): + # Francesco Montesano (CC BY-SA 3.0) https://stackoverflow.com/a/17209836 + """ + get the value and returns the value as: + integer: [0,99] + 1 digit float: [0.1, 0.99] + n*10^m: otherwise + To have all the number of the same size they are all returned as latex strings + """ + exp = np.floor(np.log10(value)) + base = value / 10**exp + if exp == 0 or exp == 1: + return f"{int(value):d}" + if exp == -1: + return f"{value:.1f}" + else: + return f"{int(base):d}e{int(exp):+d}" + class NestedParameterIterator: def __init__(self, parameters): @@ -74,6 +141,7 @@ class NestedParameterIterator: class model_runner: def __init__( self, + results_path: str | Path | None = None, # length_range: tuple[int | float] = (50e3, 50e3), # length_steps: int = 1, # length_log: bool = False, @@ -90,6 +158,8 @@ class model_runner: config: str = "signal_generation.ini", config_dir: str = None, debug: bool = False, + model_path: str | None = None, + conf_from_model:bool=True, ): """ length_range: lower and upper limit of length, in meters @@ -103,23 +173,6 @@ class model_runner: results_dir: directory to save results model_dir: directory containing models """ - self.debug = debug - - self.parameters = {} - self.iter = None - - # self.update_length_range(length_range, length_steps, length_log) - # self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log) - # self.update_osnr_range(osnr_range, osnr_steps, osnr_log) - - # self.data_dir = Path(dataset_dir) - # self.data_datetime_glob = dataset_datetime_glob - self.results_dir = Path(results_dir) - # self.model_dir = Path(model_dir) - - config_dir = config_dir or Path(__file__).parent - self.config = config_dir / config - torch.serialization.add_safe_globals([ *util.complexNN.__all__, GlobalSettings, @@ -131,23 +184,44 @@ class model_runner: torch.nn.utils.parametrizations.orthogonal, ]) - self.load_model() + self.debug = debug - self.datasets = [] + self.parameters = {} + self.iter = None + self.config = None + self.model_path = None + + self.results_dir = Path(results_dir) + + self.results_map = {} # contains parameter info, dataset path and results path + self.results_path = None + + if model_path is not None and conf_from_model: + self.config = self.get_config_from_model(model_path) + self.model_path = model_path + if self.config is None: + config_dir = config_dir or Path(__file__).parent + self.config = config_dir / config + + if results_path is not None: + self.results_map = self.load_results_map(results_path) + self.results_path = results_path # def register_parameter(self, name, config): # self.parameters.append({"name": name, "config": config}) def load_results_from_file(self, path): - data, meta = self.load_from_file(path) + data, additional, meta = self.load_from_file(path) self.results = [d.decode() for d in data] self.parameters = meta["parameters"] + return additional ... def load_datasets_from_file(self, path): - data, meta = self.load_from_file(path) + data, additional, meta = self.load_from_file(path) self.datasets = [d.decode() for d in data] self.parameters = meta["parameters"] + return additional ... def update_parameter_range(self, name, config, range, steps, log): @@ -158,65 +232,58 @@ class model_runner: raise ValueError("No parameters registered") self.iter = NestedParameterIterator(self.parameters) - def generate_datasets(self): + def generate_datasets(self, print_run_saves=False, relative:bool = True): # get base config - config = get_config(self.config) + original_config = get_config(self.config) if self.iter is None: self.generate_iterations() - for params in self.iter: + n_comb = self.iter.length + for i, params in enumerate(self.iter): + config = copy.deepcopy(original_config) + temp = {"config": {}} current_settings = [] # params is a list of dictionaries with keys "name", containing a dict with keys "value", "config" for param in params: for name, settings in param.items(): - current_settings.append({name: settings["value"]}) - self.nested_set(config, settings["config"], settings["value"]) - settings_strs = [] - for setting in current_settings: - name = list(setting)[0] - settings_strs.append(f"{name}: {float(setting[name]):.2e}") - settings_str = ", ".join(settings_strs) - print(f"Generating dataset for [{settings_str}]") - # TODO look for existing datasets - _, _, path = single_run(config) - self.datasets.append(str(path)) + value = self.nested_set(config, settings["config"], settings["value"], relative=relative) + if len(self.parameters[name]["range"]) > 1: + self.nested_set(temp["config"], settings["config"], value, relative=False) + current_settings.append((name, value, settings["value"])) + settings_str = self.show_parameter_info(current_settings, False) + print(f"({i + 1}/{n_comb}): {settings_str}", end="", flush=True) + self.results_map[settings_str] = temp - datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back") - metadata = {"parameters": self.parameters} - data = np.array(self.datasets, dtype="S") - self.save_to_file(datasets_list_path, data, **metadata) + _, _, path = single_run(config, silent=not print_run_saves) + print(f" -> {path}") + self.results_map[settings_str]["dataset"] = str(path) + # self.datasets.append(str(path)) + + self.save_results_map("Datasets") + + def show_parameter_info(self, current_settings, show=True): + settings_strs = [] + for setting in current_settings: + name, effective_value, value = setting + + settings_strs.append(f"{name}: {float(effective_value):.3e} ({float(value):+.3e})") + settings_str = ", ".join(settings_strs) + if show: + print(settings_str) + return settings_str @staticmethod - def nested_set(dic, keys, value): + def nested_set(dic: dict, keys: tuple, value: float, relative: bool = True): for key in keys[:-1]: dic = dic.setdefault(key, {}) - dic[keys[-1]] = value + if relative: + dic[keys[-1]] = float(dic[keys[-1]]) + value + else: + dic[keys[-1]] = value + return dic[keys[-1]] - ## Dataset and model loading - # def find_datasets(self, data_dir=None, data_datetime_glob=None): - # # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini - # data_dir = data_dir or self.data_dir - # data_datetime_glob = data_datetime_glob or self.data_datetime_glob - # self.datasets = {} - # data_dir = Path(data_dir) - # for length in self.lengths: - # for baudrate in self.baudrates: - # # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini" - # dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini" - # datasets = [f for f in data_dir.glob(dataset_glob)] - # if len(datasets) == 0: - # continue - # self.datasets[length] = {} - # if len(datasets) > 1: - # print( - # f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset." - # ) - # # get newest file from creation date - # datasets.sort(key=lambda x: x.stat().st_ctime) - # self.datasets[length][baudrate] = str(datasets[-1]) - - def load_dataset(self, dataset_path): + def load_dataset(self, dataset_path, angle_variance=0): if self.checkpoint_dict is None: raise ValueError("Model must be loaded before dataset") @@ -228,7 +295,10 @@ class model_runner: symbols = self.checkpoint_dict["settings"]["data_settings"].symbols data_size = self.checkpoint_dict["settings"]["data_settings"].output_size dtype = getattr(torch, self.checkpoint_dict["settings"]["data_settings"].dtype) - drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first + # drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first + # drop_last = self.checkpoint_dict["settings"]["data_settings"].drop_last + drop_first = 64 + drop_last = 64 randomise_polarisations = self.checkpoint_dict["settings"]["data_settings"].randomise_polarisations polarisations = self.checkpoint_dict["settings"]["data_settings"].polarisations num_symbols = None @@ -242,11 +312,13 @@ class model_runner: symbols=symbols, output_dim=data_size, drop_first=drop_first, + drop_last=drop_last, dtype=dtype, real=not dtype.is_complex, randomise_polarisations=randomise_polarisations, polarisations=polarisations, num_symbols=num_symbols, + cross_pol_interference=angle_variance, # device="cuda" if torch.cuda.is_available() else "cpu", ) @@ -258,6 +330,15 @@ class model_runner: # run model # return results as array: [fiber_in, fiber_out, fiber_out_noisy, regen_out] + def get_config_from_model(self, model_path): + try: + checkpoint_dict = torch.load(model_path, weights_only=True) + config_path = checkpoint_dict["settings"]["data_settings"].config_path + print("Base config loaded from model") + except (FileNotFoundError, KeyError, TypeError): + config_path = None + return config_path + def load_model(self, model_path: str | None = None): if model_path is None: self.model = None @@ -280,49 +361,47 @@ class model_runner: self.model.load_state_dict(self.checkpoint_dict["model_state_dict"]) ## Model evaluation - def run_model_evaluation(self, model_path: str, datasets: str | None = None): - self.load_model(model_path) - # iterate over datasets and osnr values: - # load dataset, add noise, run model, return results - # save results to file - self.results = [] - - if datasets is not None: - self.load_datasets_from_file(datasets) + def run_model_evaluations(self, force: bool = False): + mpath = Path(self.model_path) + model_base = mpath.stem - n_datasets = len(self.datasets) - for i, dataset_path in enumerate(self.datasets): - conf = get_config(dataset_path) - mpath = Path(model_path) - model_base = mpath.stem - print(f"({1+i}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}") + self.load_model(self.model_path) + print() + print(f"Running model {self.model_path}") - results_path = self.build_path( - dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base + n_datasets = len(self.results_map) + for i, (key, it) in enumerate(self.results_map.items()): + if not force: + if "regen" in it and it.get("model_path", None) == model_path: + # already run + continue + + print(f"({1 + i}/{n_datasets}): {key}", end="") + + dataset_path = it["dataset"] + regen_path = self.build_path( + ".".join(dataset_path.split("/")[-1].split(".")[:-1]), parent_dir=Path(self.results_dir) / model_base ) - orig_symbols = self.load_dataset(dataset_path) - + conf = get_config(dataset_path) + orig_symbols = np.swapaxes(self.load_dataset(dataset_path), 0, 1) data, loss = self.run_model() - metadata = { - "model_path": model_path, - "dataset_path": dataset_path, - "loss": loss, - "sps": conf["glova"]["sps"], - "orig_symbols": orig_symbols - # "config": conf, - # "checkpoint_dict": self.checkpoint_dict, - # "nos": self.dataloader.dataset.num_symbols, - } + self.save_to_file(regen_path, data, additional_data={"orig_symbols": orig_symbols}) + print(f" -> {regen_path}") - self.save_to_file(results_path, data, **metadata) - self.results.append(str(results_path)) + self.results_map[key]["regen"] = str(regen_path) + self.results_map[key]["model_path"] = str(self.model_path) + self.results_map[key]["model_basename"] = model_base + self.results_map[key]["sps"] = conf["glova"]["sps"] + self.results_map[key]["run_metadata"] = {"loss": loss} - results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back") - metadata = {"parameters": self.parameters} - data = np.array(self.results, dtype="S") - self.save_to_file(results_list_path, data, **metadata) + self.save_results_map("Results") + # results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back") + # metadata = {"parameters": self.parameters} + # data = np.array(self.results, dtype="S") + # self.save_to_file(results_list_path, data, metadata) + # print(f"Saved results list to file {results_list_path}") def run_model(self): loss = 0 @@ -332,8 +411,8 @@ class model_runner: model = self.model.to("cuda" if torch.cuda.is_available() else "cpu") with torch.no_grad(): for batch in self.dataloader: - x = batch["x_stacked"] - y = batch["y_stacked"] + x = batch["x"] + y = batch["y"] fiber_in = batch["plot_target"] # fiber_out = batch["plot_clean"] fiber_out = batch["plot_data"] @@ -366,38 +445,72 @@ class model_runner: return data_out, loss ## File I/O + + def save_results_map(self, save_type: str | None = None): + self.results_path = self.results_path or self.build_path( + "results_map", parent_dir=self.results_dir, timestamp="back", filetype="json" + ) + Path(self.results_path).parent.mkdir(parents=True, exist_ok=True) + with open(self.results_path, "w") as f: + json.dump(self.convert_arrays(self.results_map), f, indent=2) + save_type = save_type or "Results Map" + print(f"{save_type} saved to {self.results_path}") + @staticmethod - def save_to_file(path: str, data: np.ndarray, **metadata: dict): + def load_results_map(path: str | Path | None): + # path = path or self.results_path + if path is None: + raise ValueError("No path specified") + path = Path(path) + # self.results_path = path + with open(path, "r") as f: + results_map = json.load(f) + return results_map + + @staticmethod + def save_to_file(path: str, data: np.ndarray, metadata: dict | None = None, *, additional_data: dict | None = None): # create directory if it doesn't exist path.parent.mkdir(parents=True, exist_ok=True) with h5py.File(path, "w") as outfile: outfile.create_dataset("data", data=data) + if additional_data is not None: + if metadata is None: + metadata = {} + metadata["__datasets"] = [] + for name, dat in additional_data.items(): + if name == "data": + name = "data_1" + metadata["__datasets"].append(name) + outfile.create_dataset(name, data=dat) for key, value in metadata.items(): if isinstance(value, dict): value = json.dumps(model_runner.convert_arrays(value)) outfile.attrs[key] = value @staticmethod - def convert_arrays(dict_in): + def convert_arrays(input_object): """ convert ndarrays in (nested) dict to lists """ - dict_out = {} - for key, value in dict_in.items(): - if isinstance(value, dict): - dict_out[key] = model_runner.convert_arrays(value) - elif isinstance(value, np.ndarray): - dict_out[key] = value.tolist() - else: - dict_out[key] = value - return dict_out + + if isinstance(input_object, np.ndarray): + return input_object.tolist() + elif isinstance(input_object, list): + return [model_runner.convert_arrays(old) for old in input_object] + elif isinstance(input_object, tuple): + return tuple(model_runner.convert_arrays(old) for old in input_object) + elif isinstance(input_object, dict): + # dict_out = {} + for key, value in input_object.items(): + input_object[key] = model_runner.convert_arrays(value) + return input_object @staticmethod def load_from_file(path: str): with h5py.File(path, "r") as infile: - data = infile["data"][:] metadata = {} + additional_datasets = {} for key in infile.attrs.keys(): if isinstance(infile.attrs[key], (str, bytes, bytearray)): try: @@ -406,13 +519,33 @@ class model_runner: metadata[key] = infile.attrs[key] else: metadata[key] = infile.attrs[key] - return data, metadata + data = infile["data"][:] + if "__datasets" in metadata: + for dataset_name in metadata["__datasets"]: + additional_datasets[dataset_name] = infile[dataset_name][:] + metadata.pop("__datasets") + return data, additional_datasets, metadata ## Utility functions @staticmethod def logrange(start, stop, num, endpoint=False): - lower, upper = np.log10((start, stop)) - return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10) + # offset = 0 + # if start <= 0: + # offset = -start + 1e-9 + # start += offset + # stop += offset + upper = np.log10(stop) + if start == 0: + # lower = upper/(num-2) + lower = upper/(num-2) + num -= 1 + else: + lower = np.log10(start) + # lower, upper = np.log10((start, stop)) + rang = np.logspace(lower, upper, num=num, endpoint=endpoint, base=10) + if start == 0: + rang = np.append((0,), rang) + return rang @staticmethod def build_path( @@ -430,37 +563,56 @@ class model_runner: if parent_dir is not None: path = Path(parent_dir) / path - return path + return Path(path) @staticmethod def update_range(min, max, n_steps, log): + if n_steps == 0: + min = max = 0 + if max < min: + temp = max + max = min + min = temp + if min == max: + n_steps = 1 + log = False if log: range = model_runner.logrange(min, max, n_steps, endpoint=True) else: range = np.linspace(min, max, n_steps, endpoint=True) return range + # def print_dataset_info(self, datasets:tuple = None): + # datasets = datasets or self.datasets + # if len(datasets) == 0: + # print("No datasets loaded") + # return + # n_datasets = len(datasets) + # if n_datasets > 3: + # short_dsets = [datasets[0], "...", datasets[-1]] + # else: + # short_dsets = datasets + # print("datasets: ") + # for dset in short_dsets: + # print("\t" + dset) + # print() + # print(f"{len(datasets)} datasets loaded") + class model_evaluation_result: def __init__( self, - *, - length=None, - baudrate=None, - osnr=None, model_path=None, dataset_path=None, loss=None, sps=None, - **kwargs, + **additional_dataset_metadata, ): - self.length = length - self.baudrate = baudrate - self.osnr = osnr self.model_path = model_path self.dataset_path = dataset_path self.loss = loss self.sps = sps + self.additional_dataset_metadata = additional_dataset_metadata self.sers = None self.bers = None @@ -468,75 +620,226 @@ class model_evaluation_result: class evaluator: - def __init__(self, datasets: list[str]): + def __init__(self, results_path: str, eye_data_dir: str | Path = "tolerance_results/eye_data"): """ datasets: iterable of dataset paths data_dir: directory containing datasets """ - self.datasets = datasets - self.results = [] + self.results_map = model_runner.load_results_map(results_path) + self.results_path = results_path + self.eye_data_dir = eye_data_dir + # self.results = [] - def evaluate_datasets(self, plot=False): + def save_results_map(self, save_type: str | None = None): + self.results_path = Path(self.results_path) or self.build_path( + "results_map", parent_dir=self.results_dir, timestamp="back", filetype="json" + ) + self.results_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.results_path, "w") as f: + json.dump(model_runner.convert_arrays(self.results_map), f, indent=2) + save_type = save_type or "Analyses" + print(f"{save_type} saved to {self.results_path}") + + def evaluate_datasets( + self, + plot=False, + show_pre=True, + show_post=True, + show_SER=True, + verbose=True, + force=False, + ber_threshold=2.4e-2, + just_regen=False, + ): ## iterate over datasets # load dataset - for dataset in self.datasets: - model, dataset_name = dataset.split("/")[-2:] - print(f"\nEvaluating model {model} with dataset {dataset_name}") - data, metadata = model_runner.load_from_file(dataset) - result = model_evaluation_result(**metadata) + n_datasets = len(self.results_map) + print( + f"Evaluating {n_datasets} datasets against model {self.results_map[list(self.results_map.keys())[0]]['model_basename']} with BER threshold {ber_threshold:.3e}" + ) + for i, (k, v) in enumerate(self.results_map.items()): + regen_path = v["regen"] + model = v["model_basename"] - data = self.prepare_data(data, sps=metadata["sps"]) + if just_regen: + print(f"({i + 1}/{n_datasets}): {k}", end="", flush=True) - try: - sym_x, sym_y = metadata["orig_symbols"] - except (TypeError, KeyError, ValueError): - sym_x, sym_y = None, None + title = f"({i + 1}/{n_datasets}): {k} - {model}" + title = "\n" + title + res_str = None - self.evaluate_eye(data, result, title=dataset.split("/")[-1], plot=False) - self.evaluate_ser_ber(data, result, sym_x, sym_y) - print("BER:") - self.print_dict(result.bers["regen"]) - print() - print("SER:") - self.print_dict(result.sers["regen"]) - print() + if "analysis" in self.results_map[k] and not force: + # res_str = self.results_map[k]['analysis']['result_str'] + title += " - results from file" + + title += "\n" + "-" * len(title) + "\n" + + if not just_regen: + print(title) + # if res_str is not None: + # print(res_str) + # continue + if not ("analysis" in self.results_map[k] and not force): + data, additional, metadata = model_runner.load_from_file(regen_path) + # model_path = metadata.pop("model_path") + # dataset_path = metadata.pop("dataset_path") + result = {"metadata": metadata} + + data = self.prepare_data(data, v["sps"]) + + symbols = additional.get("orig_symbols", None) + if symbols is not None: + sym_x, sym_y = symbols + + self.evaluate_eye(data, result, title=k, plot=plot, skip_regen=not show_post) + self.evaluate_ser_ber(data, v["sps"], result, sym_x, sym_y, skip_regen=not show_post) + + for j in range(len(result["eye_stats"])): + try: + del result["eye_stats"][j]["amplitude_clusters"] + except KeyError: + ... + try: + del result["eye_stats"][j]["time_clusters"] + except KeyError: + ... + + self.results_map[k]["analysis"] = result + + else: + result = self.results_map[k]["analysis"] + if plot: + data, additional, metadata = model_runner.load_from_file(regen_path) + data = self.prepare_data(data, v["sps"]) + self.evaluate_eye(data, result, title=k, plot=True, skip_regen=not show_post) + + res_str = self.generate_analysis_report( + show_pre, show_post, show_SER, verbose, ber_threshold, result, just_regen=just_regen + ) + + print(res_str) + + self.save_results_map("Analyses") - self.results.append(result) if plot: plt.show() - - def evaluate_eye(self, data, result, title=None, plot=False): + def generate_analysis_report(self, show_pre, show_post, show_SER, verbose, ber_threshold, result, just_regen=False): + if just_regen: + return f" -> {result['bers']['regen']['combined']['total']:.3e} {'✅' if result['bers']['regen']['combined']['total'] <= ber_threshold else '❌'}" + ber_str = "BER\n" + ser_str = "SER\n" + + if verbose: + ber_str += self.stringify_dict(result["bers"]["fiber_out"], indent=1) + ser_str += self.stringify_dict(result["sers"]["fiber_out"], indent=1) + else: + ber_str += f"combined: {result['bers']['fiber_out']['combined']['total']:.3e}\n" + # ber_str += f"y: {result['bers']['fiber_out']['y']['total']:.3e}\n" + ser_str += f"combined: {result['sers']['fiber_out']['combined']['total']:.3e}\n" + # ser_str += f"y: {result['sers']['fiber_out']['y']['total']:.3e}\n" + + if show_SER: + res_str_fiber = self.combine_multiline_strings(ber_str, ser_str, padding=2) + else: + res_str_fiber = ber_str + + if ( + result["bers"]["fiber_out"]["combined"]["total"] <= ber_threshold + # and result['bers']["fiber_out"]["y"]["total"] <= ber_threshold + ): + res_str_fiber += "\nBelow BER threshold: ✅" + else: + res_str_fiber += "\nBelow BER threshold: ❌" + + res_str_fiber = "Pre regenerator\n" + res_str_fiber + + ber_str = "BER\n" + ser_str = "SER\n" + + if show_post: + if verbose: + ber_str += self.stringify_dict(result["bers"]["regen"], indent=1) + ser_str += self.stringify_dict(result["sers"]["regen"], indent=1) + else: + ber_str += f"combined: {result['bers']['regen']['combined']['total']:.3e}\n" + # ber_str += f"y: {result['bers']['regen']['y']['total']:.3e}\n" + ser_str += f"combined: {result['sers']['regen']['combined']['total']:.3e}\n" + # ser_str += f"y: {result['sers']['regen']['y']['total']:.3e}\n" + + if show_SER: + res_str_regen = self.combine_multiline_strings(ber_str, ser_str, padding=2) + else: + res_str_regen = ber_str + + if ( + result["bers"]["regen"]["combined"]["total"] <= ber_threshold + # and result['bers']["regen"]["y"]["total"] <= ber_threshold + ): + res_str_regen += "\nBelow BER threshold: ✅" + else: + res_str_regen += "\nBelow BER threshold: ❌" + + res_str_regen = "Post regenerator\n" + res_str_regen + + else: + res_str_regen = "" + + if show_pre and show_post: + res_str = self.combine_multiline_strings(res_str_fiber, res_str_regen, padding=4) + elif show_pre: + res_str = res_str_fiber + elif show_post: + res_str = res_str_regen + return res_str + + def evaluate_eye(self, data, result: dict, title=None, plot=False, skip_regen=False): + Path(self.eye_data_dir).mkdir(parents=True, exist_ok=True) eye = util.eye_diagram.eye_diagram( - data, + data[:4] if skip_regen else data, channel_names=[ "fiber_in_x", "fiber_in_y", - # "fiber_out_x", - # "fiber_out_y", + "fiber_out_x", + "fiber_out_y", + ] + if skip_regen + else [ + "fiber_in_x", + "fiber_in_y", "fiber_out_x", "fiber_out_y", "regen_x", "regen_y", ], + save_file_or_dir=self.eye_data_dir, + horizontal_bins=256, + vertical_bins=1000, ) eye.analyse() - eye.plot(title=title or "Eye diagram", show=plot) + if plot: + eye.plot( + title=title or "Eye diagram", + show=False, + # save_images=True, + # image_dir=f"{self.eye_data_dir}/plots/{self.results_map[list(self.results_map.keys())[0]]['model_basename']}" + ) - result.eye_stats = eye.eye_stats + result["eye_stats"] = eye.eye_stats + result["eye_file"] = eye.save_file return eye.eye_stats ... - def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None): - if result.eye_stats is None: + def evaluate_ser_ber(self, data, sps, result: dict, sym_x=None, sym_y=None, skip_regen=False): + if "eye_stats" not in result: self.evaluate_eye(data, result) symbols = [] sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}} bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}} - for channel_data, stats in zip(data, result.eye_stats): + for channel_data, stats in zip(data, result["eye_stats"]): timestamps = channel_data[0] dat = channel_data[1] @@ -546,53 +849,108 @@ class evaluator: time_midpoint = stats["time_midpoint"] else: if channel_name.endswith("x"): - thresholds = result.eye_stats[0]["thresholds"] - time_midpoint = result.eye_stats[0]["time_midpoint"] + thresholds = result["eye_stats"][0]["thresholds"] + time_midpoint = result["eye_stats"][0]["time_midpoint"] elif channel_name.endswith("y"): - thresholds = result.eye_stats[1]["thresholds"] - time_midpoint = result.eye_stats[1]["time_midpoint"] + thresholds = result["eye_stats"][1]["thresholds"] + time_midpoint = result["eye_stats"][1]["time_midpoint"] else: levels = np.linspace(np.min(dat), np.max(dat), 4) thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels) time_midpoint = 1.0 - # time_offset = time_midpoint - 0.5 - # # time_offset = 0 + if time_midpoint is None: + time_midpoint = 1.0 - # index_offset = np.argmin(np.abs((timestamps - time_offset) % 1.0)) - - nos = len(timestamps) // result.sps - - # idx = np.arange(index_offset, len(timestamps), result.sps).astype(int) - - # if time_offset < 0: - # idx = np.insert(idx, 0, 0) - - idx = list(range(0,len(timestamps),result.sps)) - - idx = idx[:nos] + sample_offset = int(np.round((time_midpoint - 1.0) * sps, 0)) + idx = list(range(sample_offset, len(timestamps), sps)) data_sampled = dat[idx] + detected_symbols = self.detect_symbols(data_sampled, thresholds) symbols.append({"channel_name": channel_name, "symbols": detected_symbols}) - symbols_x_gt = sym_x or symbols[0]["symbols"] - symbols_y_gt = sym_y or symbols[1]["symbols"] + symbols_x_gt = symbols[0]["symbols"] if sym_x is None else sym_x + symbols_y_gt = symbols[1]["symbols"] if sym_y is None else sym_y symbols_x_fiber_out = symbols[2]["symbols"] symbols_y_fiber_out = symbols[3]["symbols"] - symbols_x_regen = symbols[4]["symbols"] - symbols_y_regen = symbols[5]["symbols"] + if not skip_regen: + symbols_x_regen = symbols[4]["symbols"] + symbols_y_regen = symbols[5]["symbols"] - sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out) - sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out) - sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen) - sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen) + sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber( + *self.find_best_matching_subsets(symbols_x_gt, symbols_x_fiber_out) + ) + sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber( + *self.find_best_matching_subsets(symbols_y_gt, symbols_y_fiber_out) + ) - result.sers = sers - result.bers = bers + sers["fiber_out"]["combined"] = {} + sers["fiber_out"]["combined"]["n_symbols"] = ( + sers["fiber_out"]["x"]["n_symbols"] + sers["fiber_out"]["y"]["n_symbols"] + ) + sers["fiber_out"]["combined"]["n_errors"] = ( + sers["fiber_out"]["x"]["n_errors"] + sers["fiber_out"]["y"]["n_errors"] + ) + sers["fiber_out"]["combined"]["total"] = ( + sers["fiber_out"]["combined"]["n_errors"] / sers["fiber_out"]["combined"]["n_symbols"] + ) + + bers["fiber_out"]["combined"] = {} + bers["fiber_out"]["combined"]["n_bits"] = bers["fiber_out"]["x"]["n_bits"] + bers["fiber_out"]["y"]["n_bits"] + bers["fiber_out"]["combined"]["n_errors"] = ( + bers["fiber_out"]["x"]["n_errors"] + bers["fiber_out"]["y"]["n_errors"] + ) + bers["fiber_out"]["combined"]["total"] = ( + bers["fiber_out"]["combined"]["n_errors"] / bers["fiber_out"]["combined"]["n_bits"] + ) + + if not skip_regen: + sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber( + *self.find_best_matching_subsets(symbols_x_gt, symbols_x_regen) + ) + sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber( + *self.find_best_matching_subsets(symbols_y_gt, symbols_y_regen) + ) + + sers["regen"]["combined"] = {} + sers["regen"]["combined"]["n_symbols"] = sers["regen"]["x"]["n_symbols"] + sers["regen"]["y"]["n_symbols"] + sers["regen"]["combined"]["n_errors"] = sers["regen"]["x"]["n_errors"] + sers["regen"]["y"]["n_errors"] + sers["regen"]["combined"]["total"] = ( + sers["regen"]["combined"]["n_errors"] / sers["regen"]["combined"]["n_symbols"] + ) + + bers["regen"]["combined"] = {} + bers["regen"]["combined"]["n_bits"] = bers["regen"]["x"]["n_bits"] + bers["regen"]["y"]["n_bits"] + bers["regen"]["combined"]["n_errors"] = bers["regen"]["x"]["n_errors"] + bers["regen"]["y"]["n_errors"] + bers["regen"]["combined"]["total"] = ( + bers["regen"]["combined"]["n_errors"] / bers["regen"]["combined"]["n_bits"] + ) + + result["sers"] = sers + result["bers"] = bers + + def find_best_matching_subsets(self, a: np.ndarray, b: np.ndarray): + len_a = len(a) + len_b = len(b) + + if len_a == len_b: + return a, b + + if len_a < len_b: + return self.find_best_matching_subsets(b, a) + + diff = len_a - len_b + + errors = np.ones(diff) * np.inf + for i in range(diff): + errors[i] = np.sum(np.abs(np.subtract(a[i : -diff + i], b))) + shift = np.argmin(errors) + + return a[shift : -diff + shift], b @staticmethod def calculate_ser_ber(symbols_gt, symbols): @@ -608,24 +966,30 @@ class evaluator: ser = {} ber = {} ser["n_symbols"] = len(symbols_gt) - ser["n_errors"] = np.sum(symbols != symbols_gt) + ser["n_errors"] = int(np.sum(symbols != symbols_gt)) ser["total"] = float(ser["n_errors"] / ser["n_symbols"]) bec = np.vectorize(bec_map.get)(np.abs(symbols - symbols_gt)) bit_errors = np.sum(bec) ber["n_bits"] = len(symbols_gt) * 2 - ber["n_errors"] = bit_errors + ber["n_errors"] = int(bit_errors) ber["total"] = float(ber["n_errors"] / ber["n_bits"]) return ser, ber @staticmethod - def print_dict(d: dict, indent=2, logarithmic=False, level=0): - for key, value in d.items(): + def print_dict(d: dict, indent=2, logarithmic=False): + print(evaluator.stringify_dict(d, indent, logarithmic)) + + @staticmethod + def stringify_dict(d: dict, indent=2, logarithmic=False, level=0): + dict_str = "" + n = len(d.items()) - 1 + for i, (key, value) in enumerate(d.items()): if isinstance(value, dict): - print(f"{' ' * indent * level}{key}:") - evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1) + dict_str += f"{' ' * indent * level}{key}:\n" + dict_str += evaluator.stringify_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1) else: if isinstance(value, float): if logarithmic: @@ -633,10 +997,35 @@ class evaluator: value = -np.inf else: value = np.log10(value) - print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="") + dict_str += f"{' ' * indent * level}{key}: {value:.3e}" else: - print(f"{' ' * indent * level}{key}: {value}\t", end="") - print() + dict_str += f"{' ' * indent * level}{key}: {value}" + if i < n: + dict_str += "\n" + + return dict_str + + @staticmethod + def combine_multiline_strings(a: str, b: str, padding: int = 1): + padding = max(1, padding) + lines_a = a.split("\n") + lines_b = b.split("\n") + + n_lines_a = len(lines_a) + n_lines_b = len(lines_b) + + n_lines = max(n_lines_a, n_lines_b) + + target_len = padding + max(map(len, lines_a)) + + output_lines = [] + for i in range(n_lines): + item_a = lines_a[i] if i < n_lines_a else "" + item_a += " " * (target_len - len(item_a)) + item_b = lines_b[i] if i < n_lines_b else "" + output_lines.append(item_a + item_b) + + return "\n".join(output_lines) @staticmethod def detect_symbols(samples, thresholds=None): @@ -666,58 +1055,297 @@ class evaluator: data_eye = [] for channel_values in data: channel_values = np.square(np.abs(channel_values)) + # channel_values -= np.min(channel_values) + # channel_values /= np.max(channel_values) data_eye.append(np.stack((timestamps, channel_values), axis=0)) data_eye = np.stack(data_eye, axis=0) return data_eye +def slugify(value, allow_unicode=False): + """ + copright (CC BY-SA 4.0) S.Lott @ https://stackoverflow.com/a/295466 + Taken from https://github.com/django/django/blob/master/django/utils/text.py + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + """ + import unicodedata + import re -def generate_data(parameters, runner=None): - runner = runner or model_runner() + value = str(value) + if allow_unicode: + value = unicodedata.normalize('NFKC', value) + else: + value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii') + value = re.sub(r'[^\w\s-]', '', value.lower()) + return re.sub(r'[-\s]+', '_', value)#.strip('__') + +# @njit +def calc_ybounds(data): + ymax = np.max(data) + ymin = np.min(data) + yspan = ymax-ymin + ymax += 0.05*yspan + ymin -= 0.05*yspan + ymax = np.ceil(ymax*10)/10 + ymin = np.floor(ymin*10)/10 + return (ymin, ymax) + +# @njit(parallel=True, cache=True) +def power(data): + return np.square(np.abs(data)).real + +# @njit +def slow_op(data): + dat = power(data) + bounds = calc_ybounds(dat) + return dat, bounds + +memo_slow = memoizer(slow_op) + +def just_eyes(results_path, save_dir=None): + results_map = model_runner.load_results_map(results_path) + + save_dir = Path(__file__).parent if save_dir is None else Path(save_dir) + + for k,v in results_map.items(): + # make directory + save_dir_:Path = save_dir / v['model_basename'] + save_dir_.mkdir(exist_ok=True, parents=True) + + # load dataset + data, additional, metadata = model_runner.load_from_file(v['regen']) + + # channel_names = [channel['channel_name'] for channel in v['analysis']['eye_stats']] + + # datas = {name: np.square(np.abs(data[:,i])) for i,name in enumerate(channel_names)} + sps = v['sps'] + + fiber_ins = [False, False] + + for i,channel in enumerate(v['analysis']['eye_stats']): + name = channel['channel_name'] + + dat, bounds = memo_slow(data[:,i]) + bounds = tuple(map(float,bounds)) + file = f"{f"{k}_" if len(k) != 0 else ''}{bounds[0]:.1f}-{bounds[1]:.1f}_{name}" + if name == "fiber_in_x": + file = f"{bounds[0]:.1f}-{bounds[1]:.1f}_{name}" + if fiber_ins[0]: + continue + fiber_ins[0] = True + if name == "fiber_in_y": + file = f"{bounds[0]:.1f}-{bounds[1]:.1f}_{name}" + if fiber_ins[1]: + continue + fiber_ins[1] = True + #generate filename + file = save_dir_ / f"{slugify(file)}.png" + eyediagram(dat, 2*sps, offset=6, bounds=bounds, show=False, save_im=True, save_path=file) + # sys.exit() # debug: only first dataset + ... + +def generate_data(parameters, runner: model_runner, silent=True, relative=True): for param in parameters: runner.update_parameter_range(*param) runner.generate_iterations() print(f"{runner.iter.length} parameter combinations") - runner.generate_datasets() + runner.generate_datasets(print_run_saves=not silent, relative=relative) - return runner + +def is_included(item, fixed): + for kfix, vfix in fixed.items(): + try: + curr = item[kfix[0]] + for key in kfix[1:]: + curr = vfix[key] + if curr != vfix: + return False + except KeyError: + ... + return True + + +# defaults +results_path = None +model_path = None +generate_only = False +fiber_only = False +parameters = (("lambda0", ("glova", "lambda0"), (0, 0), 0, False),) +ber_threshold = 3.8e-3 +relative = True + +def plot_bit_error_ratio(is_included, parameters, ber_threshold, eval): + results_map = eval.results_map + + # plot = ("fiber", "ortho_error") + plot = parameters[0][1] + # fixed = {("fiber", "s"):0} + fixed = {} + + xs = [] + bers = [] + # losses = [] + + for k, v in results_map.items(): + if not is_included(v["config"], fixed): + continue + xs.append(v["config"][plot[0]]) + for key in plot[1:]: + xs[-1] = xs[-1][key] + bers.append(v["analysis"]["bers"]["regen"]["combined"]["total"]) + # losses.append(v['run_metadata']['loss']) + + bers_min = np.min(bers) + exp = np.log10(np.max((bers_min, 1e-5))) + ymin = np.pow(10, np.floor(exp)) + + bers_max = np.max(bers) + exp = np.log10(bers_max) + ymax = np.pow(10, np.ceil(exp)) + + from matplotlib import ticker + + plt.plot(xs, bers, "o--") + plt.axhline(ber_threshold, color="red", linestyle="--") + ax = plt.gca() + ax.set_xticks(xs) + plt.text( + ax.get_xbound()[0] - 0.2, + ber_threshold, + f"{ber_threshold:.2e}", + horizontalalignment="right", + verticalalignment="center", + ) + ax.yaxis.set_major_locator(ticker.LogLocator(subs="all")) + ax.yaxis.set_minor_formatter(ticker.LogFormatter(labelOnlyBase=False, minor_thresholds=(2, 0.4))) + ax.set_yscale("log") + if parameters[0][-1]: + ax.set_xscale("symlog") + ax.set_ylim(ymin, ymax) + ax.yaxis.set_label_position("left") + ax.set_xlabel(parameters[0][0]) + ax.set_ylabel("Bit Error Ratio") + + # ax1 = ax.twiny() + # ax1.yaxis.set_label_position('left') + # ax1.yaxis.set_major_locator(ticker.FixedLocator((ber_threshold,))) + # ax1.set_yscale("log") + # thresh_ax = plt.gca() + # thresh_ax.set_yticks((ber_threshold,)) + # thresh_ax.set_yscale("log") + + # ymin = np.min((plt.gca().get_ylim()[0], 1e-5)) + # ymax = plt.gca().get_ylim()[1] + + # plt.ylim(ymin, ymax) + plt.grid(which="both") if __name__ == "__main__": - model_path = ".models/best_20250110_191149.tar" # D 17, OSNR 100, delta_beta 0.14, baud 10e9 - parameters = ( # name, config keys, (min, max), n_steps, log - # ("D", ("fiber", "d"), (28,30), 3, False), - # ("S", ("fiber", "s"), (0, 0.058), 2, False), - ("OSNR", ("signal", "osnr"), (20, 40), 5, False), - # ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False), - # ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True), + # ("OSNR", ("signal", "osnr"), (1000, 1000), 1, False), + ("lambda0", ("glova", "lambda0"), (1530, 1565), 11, False), + # ("length", ("fiber", "length"), (-80e3, 80e3), 0, False), + # ("Baud", ("glova", "symbolrate"), (-5e9, 5e9), 0, False), + # ("D", ("fiber", "d"), (-33, -1), 17, False), # fiber_out testing -> D = -16 .. 16 + # ("D", ("fiber", "d"), (-16, 16), 17, False), # regen testing -> D = 1 .. 33 + # ("PMDq", ("fiber", "pmd_q"), (-0.2, -0.2), 1, False), + # ("birefsteps", ("fiber", "birefsteps"), (-999, -999), 1, False), + # ("pol_error", ("fiber", "pol_error"), (0, 0.8), 11, False), + # ("dgd", ("fiber", "dgd"), (0,1e3), 11, True), + # ("ortho_error", ("fiber", "ortho_error"), (0, 0.4), 11, False), # 0.05 * 180/pi = 2.86 degrees + ) + relative = False + + # ### model a) Dispersion, Slope + model_path = ".models/best_20250118_225918.tar" + # /mnt/c/Users/Joseph/Documents/Uni/ELM/MA/Masterarbeit/typst/images/BER_V_DGD.png + # results_path = "tolerance_results/datasets/results_map_20250127_194416.json" + # # + # ## dispersion sweep 17 steps (100 .. 3300 ps/nm) (Delta: -1600 .. 1600 ps/nm) + # results_path = "tolerance_results/datasets/results_map_20250119_020331.json" + # ## pol error sweep 9 steps (0 .. 0.8) (Delta: 0 .. 0.8) + # results_path = "tolerance_results/datasets/results_map_20250126_175011.json" + # ortho sweep 9 steps (0 .. 0.2) (Delta: -0.1 .. 0.1) + # results_path = "tolerance_results/datasets/results_map_20250126_175923.json" + # # ## dgd sweep 10 steps (11 .. 1010ps) (Delta: 1 .. 1000 ps) + # results_path = "tolerance_results/datasets/results_map_20250126_181310.json" + + + ### model b) Dispersion, Slope, pol_error 0.4 + # model_path = ".models/best_20250116_214816.tar" + + # ## dispersion sweep 17 steps (100 .. 3300 ps/nm) (Delta: -1600 .. 1600 ps/nm) + # results_path = "tolerance_results/datasets/results_map_20250119_104042.json" + # ## pol error sweep 9 steps (0 .. 0.8) (Delta: -0.4 .. 0.4) + # results_path = "tolerance_results/datasets/results_map_20250119_224924.json" + + + # ## model c) ortho error 0.1 + # model_path = ".models/best_20250117_122319.tar" + + # # dispersion sweep 17 steps (100 .. 3300 ps/nm) (Delta: -1600 .. 1600 ps/nm) + # results_path = "tolerance_results/datasets/results_map_20250119_165202.json" + # ortho sweep 9 steps (0 .. 0.2) (Delta: -0.1 .. 0.1) + # results_path = "tolerance_results/datasets/results_map_20250120_123540.json" + + + # # ## model d) dgd 10ps + # model_path = ".models/best_20250117_144001.tar" + + # # ## dispersion sweep 17 steps (100 .. 3300 ps/nm) (Delta: -1600 .. 1600 ps/nm) + # # results_path = "tolerance_results/datasets/results_map_20250119_184907.json" + # # ## dgd sweep 10 steps (11 .. 1010ps) (Delta: 1 .. 1000 ps) + # results_path="tolerance_results/datasets/results_map_20250120_110445.json" + + + runner = model_runner(results_path=results_path, model_path=model_path, conf_from_model=True) + + # print(runner.config) + # print(json.dumps(get_config(runner.config), indent=2)) + # sys.exit() + + forYourEyesOnly = False + generate_only = False + plot_BER_curves = True + + if forYourEyesOnly: + just_eyes(results_path, save_dir="tolerance_results/eye_plots") + sys.exit() + + if generate_only: + runner = model_runner() + generate_data(parameters, runner, silent=True, relative=relative) + sys.exit() + + if results_path is None: + generate_data(parameters, runner, relative=relative) + runner.model_path = model_path + runner.run_model_evaluations() + results_path = runner.results_path + + eval = evaluator(results_path) + eval.evaluate_datasets( + show_pre=False, + show_post=True, + show_SER=False, + verbose=False, + force=False, + plot=False, + ber_threshold=ber_threshold, + just_regen=True, ) - datasets = None - results = None + # just_eyes(results_path, save_dir="tolerance_results/eye_plots") - # datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5" - results = "tolerance_results/datasets/results_list_20250110_232639.h5" + if plot_BER_curves: + plot_bit_error_ratio(is_included, parameters, ber_threshold, eval) - runner = model_runner() - # generate_data(parameters, runner) - - - if results is None: - if datasets is None: - generate_data(parameters, runner) - else: - runner.load_datasets_from_file(datasets) - print(f"{len(runner.datasets)} loaded") - - runner.run_model_evaluation(model_path) - else: - runner.load_results_from_file(results) - - # print(runner.parameters) - # print(runner.results) - - eval = evaluator(runner.results) - eval.evaluate_datasets(plot=True) + try: + plt.show() + except KeyboardInterrupt: + plt.close() diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py index f488f5b..017b751 100644 --- a/src/single-core-regen/util/complexNN.py +++ b/src/single-core-regen/util/complexNN.py @@ -441,8 +441,7 @@ class input_rotator(nn.Module): # return out -#### as defined by zhang et al - +#### as defined by zhang et alas class DropoutComplex(nn.Module): def __init__(self, p=0.5): @@ -464,7 +463,7 @@ class Scale(nn.Module): self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32)) def forward(self, x): - return x * self.scale + return x * torch.sqrt(self.scale) def __repr__(self): return f"Scale({self.size})" @@ -546,35 +545,31 @@ class EOActivation(nn.Module): raise ValueError("Size must be specified") self.size = size self.alpha = nn.Parameter(torch.rand(size)) - self.V_bias = nn.Parameter(torch.rand(size)) self.gain = nn.Parameter(torch.rand(size)) - # if bias: - # self.phase_bias = nn.Parameter(torch.zeros(size)) - # else: - # self.register_buffer("phase_bias", torch.zeros(size)) - # self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi) - self.register_buffer("responsivity", torch.ones(size)*0.9) - self.register_buffer("V_pi", torch.ones(size)*3) + self.V_bias = nn.Parameter(torch.rand(size)) + # self.register_buffer("gain", torch.ones(size)) + # self.register_buffer("responsivity", torch.ones(size)) + # self.register_buffer("V_pi", torch.ones(size)) self.reset_weights() def reset_weights(self): if "alpha" in self._parameters: self.alpha.data = torch.rand(self.size) - if "V_pi" in self._parameters: - self.V_pi.data = torch.rand(self.size)*3 + # if "V_pi" in self._parameters: + # self.V_pi.data = torch.rand(self.size)*3 if "V_bias" in self._parameters: self.V_bias.data = torch.randn(self.size) if "gain" in self._parameters: self.gain.data = torch.rand(self.size) - if "responsivity" in self._parameters: - self.responsivity.data = torch.ones(self.size)*0.9 + # if "responsivity" in self._parameters: + # self.responsivity.data = torch.ones(self.size)*0.9 # if "bias" in self._parameters: # self.phase_bias.data = torch.zeros(self.size) def forward(self, x: torch.Tensor): - phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8) - g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8) + phi_b = torch.pi * self.V_bias# / (self.V_pi) + g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi) intermediate = g_phi * x.abs().square() + phi_b return ( 1j diff --git a/src/single-core-regen/util/core.py b/src/single-core-regen/util/core.py new file mode 100644 index 0000000..1fb1c68 --- /dev/null +++ b/src/single-core-regen/util/core.py @@ -0,0 +1,105 @@ +# Copyright (c) 2015, Warren Weckesser. All rights reserved. +# This software is licensed according to the "BSD 2-clause" license. + +import hashlib +import h5py +import numpy as _np +from scipy.interpolate import interp1d as _interp1d +from scipy.ndimage import gaussian_filter as _gaussian_filter +from ._brescount import bres_curve_count as _bres_curve_count +from pathlib import Path + + +__all__ = ['grid_count'] + + +def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None): + """ + Parameters + ---------- + `y` is the 1-d array of signal samples. + + `window_size` is the number of samples to show horizontally in the + eye diagram. Typically this is twice the number of samples in a + "symbol" (i.e. in a data bit). + + `offset` is the number of initial samples to skip before computing + the eye diagram. This allows the overall phase of the diagram to + be adjusted. + + `size` must be a tuple of two integers. It sets the size of the + array of counts, (height, width). The default is (800, 640). + + `fuzz`: If True, the values in `y` are reinterpolated with a + random "fuzz factor" before plotting in the eye diagram. This + reduces an aliasing-like effect that arises with the use of + Bresenham's algorithm. + + `bounds` must be a tuple of two floating point values, (ymin, ymax). + These set the y range of the returned array. If not given, the + bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is + `y.max() - y.min()`. + + Return Value + ------------ + Returns a numpy array of integers. + + """ + # hash input params + param_ob = (y, window_size, offset, size, fuzz, blur, bounds) + param_hash = hashlib.md5(str(param_ob).encode()).hexdigest() + cache_dir = Path.home()/".eyediagram"/".cache" + cache_dir.mkdir(parents=True, exist_ok=True) + if (cache_dir/param_hash).is_file(): + try: + with h5py.File(cache_dir/param_hash, "r") as infile: + counts = infile["counts"][:] + if counts.len() != 0: + return counts + except: + pass + + + + if size is None: + size = (800, 640) + height, width = size + dt = width / window_size + counts = _np.zeros((width, height), dtype=_np.int32) + + if bounds is None: + ymin = y.min() + ymax = y.max() + yamp = ymax - ymin + ymin = ymin - 0.05*yamp + ymax = ymax + 0.05*yamp + ymax = _np.ceil(ymax*10)/10 + ymin = _np.floor(ymin*10)/10 + else: + ymin, ymax = bounds + + start = offset + while start + window_size < len(y): + end = start + window_size + yy = y[start:end+1] + k = _np.arange(len(yy)) + xx = dt*k + if fuzz: + f = _interp1d(xx, yy, kind='cubic') + jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5) + xx[1:-1] += jiggle + yd = f(xx) + else: + yd = yy + iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32) + _bres_curve_count(xx.astype(_np.int32), iyd, counts) + + start = end + + if blur != 0: + counts = _gaussian_filter(counts, sigma=blur) + + with h5py.File(cache_dir/param_hash, "w") as outfile: + outfile.create_dataset("data", data=counts) + + return counts diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py index fe6d01a..747610c 100644 --- a/src/single-core-regen/util/datasets.py +++ b/src/single-core-regen/util/datasets.py @@ -25,13 +25,14 @@ import multiprocessing as mp # def __len__(self): # return len(self.indices) + def load_from_file(datapath): - if str(datapath).endswith('.h5'): + if str(datapath).endswith(".h5"): symbols = None with h5py.File(datapath, "r") as infile: data = infile["data"][:] try: - symbols = infile["symbols"][:] + symbols = np.swapaxes(infile["symbols"][:], 0, 1) except KeyError: pass else: @@ -40,7 +41,7 @@ def load_from_file(datapath): return data, symbols -def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None): +def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None): filepath = Path(config_path) filepath = filepath.parent.glob(filepath.name) config = configparser.ConfigParser() @@ -55,15 +56,23 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, d if symbols is None: symbols = int(config["glova"]["nos"]) - skipfirst - + data, orig_symbols = load_from_file(datapath) - data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)] - orig_symbols = orig_symbols[skipfirst:symbols+skipfirst] - timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps)) + data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)] + orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast] + timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps)) data *= np.sqrt(normalize) + launch_power = float(config["signal"]["laser_power"]) + output_power = float(config["signal"]["edfa_power"]) + + target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10) + # target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal + + data[:, 0:2] *= np.sqrt(target_normalization) + # if normalize: # # square gets normalized to 1, as the power is (proportional to) the square of the amplitude # a, b, c, d = data.T @@ -132,13 +141,15 @@ class FiberRegenerationDataset(Dataset): target_delay: float | int = 0, xy_delay: float | int = 0, drop_first: float | int = 0, + drop_last=0, dtype: torch.dtype = None, real: bool = False, device=None, # osnr: float|None = None, - polarisations = None, + polarisations=None, randomise_polarisations: bool = False, repeat_randoms: int = 1, + # cross_pol_interference: float = 0, **kwargs, ): """ @@ -172,6 +183,7 @@ class FiberRegenerationDataset(Dataset): assert drop_first >= 0, "drop_first must be non-negative" self.randomise_polarisations = randomise_polarisations + # self.cross_pol_interference = cross_pol_interference data_raw = None self.config = None @@ -181,6 +193,7 @@ class FiberRegenerationDataset(Dataset): data, config, orig_syms = load_data( file_path, skipfirst=drop_first, + skiplast=drop_last, symbols=kwargs.get("num_symbols", None), real=real, normalize=1000, @@ -192,7 +205,7 @@ class FiberRegenerationDataset(Dataset): self.orig_symbols = orig_syms else: self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1) - + if data_raw is None: data_raw = data else: @@ -300,20 +313,18 @@ class FiberRegenerationDataset(Dataset): # fiber_out: [E_out_x, E_out_y, timestamps] # add noise related to amplification necessary due to splitting of the signal - gain_lin = output_dim*2 - edfa_nf = float(self.config["signal"]["edfa_nf"]) - nf_lin = 10**(edfa_nf/10) - f0 = float(self.config["glova"]["f0"]) - - noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9 - - noise = torch.randn_like(fiber_out[:2, :]) - noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1)) - noise = noise * torch.sqrt(noise_add / noise_power) - fiber_out[:2, :] += noise - + # gain_lin = output_dim*2 + # gain_lin = 1 + # edfa_nf = float(self.config["signal"]["edfa_nf"]) + # nf_lin = 10**(edfa_nf/10) + # f0 = float(self.config["glova"]["f0"]) + # noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9 + # noise = torch.randn_like(fiber_out[:2, :]) + # noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1)) + # noise = noise * torch.sqrt(noise_add / noise_power) + # fiber_out[:2, :] += noise # if osnr is None: # noisy = fiber_out[:2, :] @@ -324,7 +335,6 @@ class FiberRegenerationDataset(Dataset): # fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy] - if repeat_randoms > 1: fiber_in = fiber_in.repeat(1, 1, repeat_randoms) fiber_out = fiber_out.repeat(1, 1, repeat_randoms) @@ -334,12 +344,13 @@ class FiberRegenerationDataset(Dataset): if self.randomise_polarisations: angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi - # start_angle = torch.rand(1) * 2 * torch.pi - # angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk + start_angle = torch.rand(1) * 2 * torch.pi + angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk + angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi else: angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device) - + sin = torch.sin(angles) cos = torch.cos(angles) rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2) @@ -353,16 +364,14 @@ class FiberRegenerationDataset(Dataset): # 1 E_in_y, # 2 timestamps - # fiber_out: - # 0 E_out_x, - # 1 E_out_y, + # fiber_out: + # 0 E_out_x, + # 1 E_out_y, # 2 timestamps, - # 3 E_out_x_rot, - # 4 E_out_y_rot, + # 3 E_out_x_rot, + # 4 E_out_y_rot, # 5 angle - - # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) # data layout # [ [E_in_x, E_in_y, timestamps], @@ -374,9 +383,12 @@ class FiberRegenerationDataset(Dataset): self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1) self.fiber_out = self.fiber_out.movedim(-2, 0) + # if self.randomise_polarisations: + # self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0) + # self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) # self.data = self.data.movedim(-2, 0) - # self.angles = torch.zeros(self.data.shape[0]) + # self.angles = torch.zeros(self.data.shape[0]) ... # ... # -> [no_slices, 2, 3, samples_per_slice] @@ -390,14 +402,14 @@ class FiberRegenerationDataset(Dataset): def __len__(self): return self.fiber_in.shape[0] - + def add_noise(self, data, osnr): - osnr_lin = 10**(osnr/10) + osnr_lin = 10 ** (osnr / 10) popt = torch.mean(data.abs().square().squeeze(), dim=-1) noise = torch.randn_like(data) pn = torch.mean(noise.abs().square().squeeze(), dim=-1) - - mult = torch.sqrt(popt/(pn*osnr_lin)) + + mult = torch.sqrt(popt / (pn * osnr_lin)) mult = mult * torch.eye(popt.shape[0], device=mult.device) mult = mult.to(dtype=noise.dtype) @@ -406,7 +418,6 @@ class FiberRegenerationDataset(Dataset): noisy = data + noise return noisy - def __getitem__(self, idx): if isinstance(idx, slice): return [self.__getitem__(i) for i in range(*idx.indices(len(self)))] @@ -418,6 +429,10 @@ class FiberRegenerationDataset(Dataset): output_dim = self.output_dim // 2 self.output_dim = output_dim * 2 + if not self.polarisations: + output_dim = 2 * output_dim + + fiber_in = self.fiber_in[idx].squeeze() fiber_out = self.fiber_out[idx].squeeze() @@ -427,85 +442,35 @@ class FiberRegenerationDataset(Dataset): fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1) fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1) - - # data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim] - - # data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1) - - # angle = self.angles[idx] - - # fiber_in: - # 0 E_in_x, - # 1 E_in_y, - # 2 timestamps - - # fiber_out: - # 0 E_out_x, - # 1 E_out_y, - # 2 timestamps, - # 3 E_out_x_rot, - # 4 E_out_y_rot, - # 5 angle - - center_angle = fiber_out[0, output_dim // 2, 0] + center_angle = fiber_out[5, output_dim // 2, 0] angles = fiber_out[5, :, 0] - plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone() plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone() data = fiber_out[0:2, :, 0] - # fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone() + plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone() - - # if self.polarisations: - # rot = int(np.random.randint(2)*2-1) - # pol_flipped_data[0:1, :] = rot*data[0, :] - # pol_flipped_data[1, :] = rot*data[1, :] - # plot_data_rot[0] = rot*plot_data_rot[0] - # plot_data_rot[1] = rot*plot_data_rot[1] - # center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0 - # angles = angles + (rot - 1) * torch.pi/2 - - - # if self.randomise_polarisations: - # data = data.mT - # c = torch.cos(angle).unsqueeze(-1) - # s = torch.sin(angle).unsqueeze(-1) - # rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1) - # data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1) - ... - - # angle = torch.zeros_like(angle) - - # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter) - # angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim) - # angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim) - # sop = self.polarimeter(plot_data) - # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1) - # angle = data_slice[1, 3, self.output_dim // 2, 0].real target = fiber_in[:2, output_dim // 2, 0] plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone() target_timestamp = fiber_in[2, output_dim // 2, 0].real ... if self.polarisations: - rot = int(np.random.randint(2)*2-1) - data = rot*data - target = rot*target - plot_data_rot = rot*plot_data_rot - center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0 - angles = angles + (rot - 1) * torch.pi/2 + rot = int(np.random.randint(2) * 2 - 1) + data = rot * data + target = rot * target + plot_data_rot = rot * plot_data_rot + center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0 + angles = angles + (rot - 1) * torch.pi / 2 pol_flipped_data = -data pol_flipped_target = -target - # data_timestamps = data[-1,:].real - # data = data[:-1, :] - # target_timestamp = target[-1].real - # target = target[:-1] - # plot_data = plot_data[:-1] - # transpose to interleave the x and y data in the output tensor data = data.transpose(0, 1).flatten().squeeze() + data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze() + pol_flipped_data = pol_flipped_data / torch.sqrt( + torch.ones(1) * len(pol_flipped_data) + ) # power loss due to splitting # angle_data = angle_data.transpose(0, 1).flatten().squeeze() # angle_data2 = angle_data2.transpose(0,1).flatten().squeeze() center_angle = center_angle.flatten().squeeze() @@ -526,10 +491,10 @@ class FiberRegenerationDataset(Dataset): "y": target, "y_flipped": pol_flipped_target, "y_stacked": torch.cat([target, pol_flipped_target], dim=-1), - # "center_angle": center_angle, - # "angles": angles, + "center_angle": center_angle, + "angles": angles, "mean_angle": angles.mean(), - # "sop": sop, + # "sop": sop, # "angle_data": angle_data, # "angle_data2": angle_data2, "timestamp": target_timestamp, diff --git a/src/single-core-regen/util/eye_diagram.py b/src/single-core-regen/util/eye_diagram.py index 74ea0fe..99bef31 100644 --- a/src/single-core-regen/util/eye_diagram.py +++ b/src/single-core-regen/util/eye_diagram.py @@ -1,16 +1,23 @@ +from datetime import datetime +import json +from pathlib import Path +from typing import Literal +import h5py from matplotlib import pyplot as plt from matplotlib.colors import LinearSegmentedColormap +# from cmap import Colormap as cm import numpy as np from scipy.cluster.vq import kmeans2 import warnings import multiprocessing from rich.traceback import install -from rich import pretty -from rich import print install() -pretty.install() +# from rich import pretty +# from rich import print + +# pretty.install() def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1): @@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1): xaxis = np.arange(0, len(signal)) / sps return np.vstack([xaxis, signal]) + def create_symbol_sequence(n_symbols, skew=1): np.random.seed(42) data = np.random.randint(0, 4, n_symbols) / 4 @@ -39,6 +47,14 @@ def generate_signal(data, sps): signal = np.convolve(data_padded, wavelet) signal = np.cumsum(signal) signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2] + mi, ma = np.min(signal), np.max(signal) + + signal = (signal - mi) / (ma - mi) + + mod = 0.8 + + signal *= mod + signal += 1 - mod return signal @@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0): signal += awgn # min-max normalization - signal = signal - np.min(signal) - signal = signal / np.max(signal) + # signal = signal - np.min(signal) + # signal = signal / np.max(signal) return signal @@ -68,84 +84,248 @@ def generate_wavelet(sps, oversample=3): class eye_diagram: - def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True): + def __init__( + self, + data, + *, + channel_names=None, + horizontal_bins=256, + vertical_bins=1000, + n_levels=4, + multithreaded=True, + save_file_or_dir=None, + ): # data has shape [channels, 2, samples] # each sample has a timestamp and a value if data.ndim == 2: data = data[np.newaxis, :, :] - self.channel_names = channel_names self.raw_data = data - self.channels = data.shape[0] + + self.y_bins = np.zeros(1) + self.x_bins = np.zeros(1) + self.eye_data = np.zeros(1) + self.channel_names = channel_names + self.n_channels = data.shape[0] self.n_levels = n_levels - self.eye_stats = [{"success": False} for _ in range(self.channels)] + self.eye_stats = [{"success": False} for _ in range(self.n_channels)] self.horizontal_bins = horizontal_bins self.vertical_bins = vertical_bins self.multi_threaded = multithreaded + self.analysed = False self.eye_built = False - def generate_eye_data(self): - self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False) - self.y_bins = np.zeros((self.channels, self.vertical_bins)) - self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins)) - datas = [self.raw_data[i] for i in range(self.channels)] - if self.multi_threaded: - with multiprocessing.Pool() as pool: - results = pool.map(self.generate_eye_data_single, datas) - for i, result in enumerate(results): - self.eye_data[i], self.y_bins[i] = result + self.save_file = save_file_or_dir + + def load_data(self, file=None): + file = self.save_file if file is None else file + + if file is None: + raise FileNotFoundError("No file specified.") + + self.save_file = str(file) + # self.file_or_dir = self.save_file + with h5py.File(file, "r") as infile: + self.y_bins = infile["y_bins"][:] + self.x_bins = infile["x_bins"][:] + self.eye_data = infile["eye_data"][:] + self.channel_names = infile.attrs["channel_names"] + self.n_channels = infile.attrs["n_channels"] + self.n_levels = infile.attrs["n_levels"] + self.eye_stats = infile.attrs["eye_stats"] + self.eye_stats = [json.loads(stat) for stat in self.eye_stats] + self.horizontal_bins = infile.attrs["horizontal_bins"] + self.vertical_bins = infile.attrs["vertical_bins"] + self.multi_threaded = infile.attrs["multithreaded"] + self.analysed = infile.attrs["analysed"] + self.eye_built = infile.attrs["eye_built"] + + def save_data(self, file_or_dir=None): + file_or_dir = self.save_file if file_or_dir is None else file_or_dir + if file_or_dir is None: + file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5") + elif Path(file_or_dir).is_dir(): + file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5" else: - for i, data in enumerate(datas): - self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data) - self.eye_built = True - + file = Path(file_or_dir) + + # file.parent.mkdir(parents=True, exist_ok=True) + + self.save_file = str(file) + + with h5py.File(file, "w") as outfile: + outfile.create_dataset("eye_data", data=self.eye_data) + outfile.create_dataset("y_bins", data=self.y_bins) + outfile.create_dataset("x_bins", data=self.x_bins) + outfile.attrs["channel_names"] = self.channel_names + outfile.attrs["n_channels"] = self.n_channels + outfile.attrs["n_levels"] = self.n_levels + self.eye_stats = eye_diagram.convert_arrays(self.eye_stats) + outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats] + outfile.attrs["horizontal_bins"] = self.horizontal_bins + outfile.attrs["vertical_bins"] = self.vertical_bins + outfile.attrs["multithreaded"] = self.multi_threaded + outfile.attrs["analysed"] = self.analysed + outfile.attrs["eye_built"] = self.eye_built + + @staticmethod + def convert_arrays(input_object): + """ + convert ndarrays in (nested) dict to lists + """ + + if isinstance(input_object, np.ndarray): + return input_object.tolist() + elif isinstance(input_object, list): + return [eye_diagram.convert_arrays(old) for old in input_object] + elif isinstance(input_object, tuple): + return tuple(eye_diagram.convert_arrays(old) for old in input_object) + elif isinstance(input_object, dict): + dict_out = {} + for key, value in input_object.items(): + dict_out[key] = eye_diagram.convert_arrays(value) + return dict_out + return input_object + + def generate_eye_data( + self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None + ): + # modes: + # default: try to load eye data from file, if not found, generate and save + # load: try to load eye data from file, if not found, generate but don't save + # save: generate eye data and save + update_save = True + if mode == "load": + self.load_data(file_or_dir) + update_save = False + elif mode == "default": + try: + self.load_data(file_or_dir) + update_save = False + except (FileNotFoundError, IsADirectoryError): + pass + + if not self.eye_built: + update_save = True + self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False) + self.y_bins = np.zeros((self.n_channels, self.vertical_bins)) + self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins)) + datas = [self.raw_data[i] for i in range(self.n_channels)] + if self.multi_threaded: + with multiprocessing.Pool() as pool: + results = pool.map(self.generate_eye_data_single, datas) + for i, result in enumerate(results): + self.eye_data[i], self.y_bins[i] = result + else: + for i, data in enumerate(datas): + self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data) + self.eye_built = True + + if mode == "save" or (mode == "default" and update_save): + self.save_data(file_or_dir) + def generate_eye_data_single(self, data): eye_data = np.zeros((self.vertical_bins, self.horizontal_bins)) data_min = np.min(data[1, :]) data_max = np.max(data[1, :]) + # round down/up to 1 decimal + data_min = np.floor(data_min*10)/10 + data_max = np.ceil(data_max*10)/10 + # data_range = data_max - data_min + # data_min -= 0.1 * data_range + # data_max += 0.1 * data_range + # data_min = -0.05 + # data_max += 0.05 + # data[1,:] -= np.min(data[1, :]) + # data[1,:] /= np.max(data[1, :]) + # data_min = 0 + # data_max = 1 y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False) - t_vals = data[0, :] % 2 - val_vals = data[1, :] + t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512) + val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320) x_indices = np.digitize(t_vals, self.x_bins) - 1 y_indices = np.digitize(val_vals, y_bins) - 1 np.add.at(eye_data, (y_indices, x_indices), 1) return eye_data, y_bins - def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True): + def plot( + self, + title="Eye Diagram", + stats=True, + all_stats=True, + show=True, + mode: Literal["default", "load", "save", "nosave"] = "default", + # save_images = False, + # image_dir = None, + # cmap=None, + ): + if stats and not self.analysed: + self.analyse(mode=mode) if not self.eye_built: - self.generate_eye_data() + self.generate_eye_data(mode=mode) cmap = LinearSegmentedColormap.from_list( "eyemap", - [(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")], + [ + (0, "#FFFFFF00"), + (0.1, "blue"), + (0.2, "cyan"), + (0.5, "green"), + (0.8, "yellow"), + (0.9, "red"), + (1, "magenta"), + ], ) - if self.channels % 2 == 0: + # cmap = cm('google:turbo_r' if cmap is None else cmap) + # first = cmap(-1) + # cmap = cmap.to_mpl() + # cmap.set_under(first, alpha=0) + if self.n_channels % 2 == 0: rows = 2 - cols = self.channels // 2 + cols = self.n_channels // 2 else: - cols = int(np.ceil(np.sqrt(self.channels))) - rows = int(np.ceil(self.channels / cols)) + cols = int(np.ceil(np.sqrt(self.n_channels))) + rows = int(np.ceil(self.n_channels / cols)) fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False) fig.suptitle(title) fig.tight_layout() ax = np.atleast_1d(ax).transpose().flatten() - for i in range(self.channels): - ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}") - if (i+1) % rows == 0: + for i in range(self.n_channels): + ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}") + if (i + 1) % rows == 0: ax[i].set_xlabel("Symbol") if i < rows: ax[i].set_ylabel("Amplitude") ax[i].grid() + ax[i].set_axisbelow(True) ax[i].imshow( - self.eye_data[i], + self.eye_data[i] - 0.1, origin="lower", aspect="auto", cmap=cmap, extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]], + interpolation="gaussian", + vmin=0, + zorder=3, ) ax[i].set_xlim((self.x_bins[0], self.x_bins[-1])) ymin = np.min(self.y_bins[:, 0]) ymax = np.max(self.y_bins[:, -1]) yspan = ymax - ymin ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan)) + # if save_images: + # image_dir = "images_out" if image_dir is None else image_dir + # image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png") + # image_path.parent.mkdir(parents=True, exist_ok=True) + # # plt.imsave( + # # image_path, + # # self.eye_data[i] - 0.1, + # # origin="lower", + # # # aspect="auto", + # # cmap=cmap, + # # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]], + # # # interpolation="gaussian", + # # vmin=0, + # # # zorder=3, + # # ) if stats and self.eye_stats[i]["success"]: # # add min_area above the plot # ax[i].annotate( @@ -159,7 +339,7 @@ class eye_diagram: if all_stats: ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--") - y_ticks = (*self.eye_stats[i]["levels"],*self.eye_stats[i]["thresholds"]) + y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"]) # y_ticks = np.sort(y_ticks) ax[i].set_yticks(y_ticks) # add arrows for amplitudes @@ -230,24 +410,24 @@ class eye_diagram: if show: plt.show() return fig - + @staticmethod def calculate_thresholds(levels): ret = np.cumsum(levels, dtype=float) ret[2:] = ret[2:] - ret[:-2] - return ret[1:]/2 + return ret[1:] / 2 def analyse_single(self, data, index): warnings.filterwarnings("error") eye_stats = {} - eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index] + eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index] try: approx_levels = eye_diagram.approximate_levels(data, self.n_levels) time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels) - eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2 - eye_stats["time_midpoint"] = 1.0 + eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2) + # eye_stats["time_midpoint"] = 1.0 eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels( data, approx_levels, time_bounds @@ -257,9 +437,7 @@ class eye_diagram: eye_stats["amplitudes"] = np.diff(eye_stats["levels"]) - eye_stats["heights"] = eye_diagram.calculate_eye_heights( - eye_stats["amplitude_clusters"] - ) + eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"]) eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths( data, eye_stats["levels"] @@ -291,17 +469,39 @@ class eye_diagram: warnings.resetwarnings() return eye_stats + def analyse( + self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None + ): + # modes: + # default: try to load eye data from file, if not found, generate and save + # load: try to load eye data from file, if not found, generate but don't save + # save: generate eye data and save + update_save = True + if mode == "load": + self.load_data(file_or_dir) + update_save = False + elif mode == "default": + try: + self.load_data(file_or_dir) + update_save = False + except (FileNotFoundError, IsADirectoryError): + pass - def analyse(self): - self.eye_stats = [] - if self.multi_threaded: - with multiprocessing.Pool() as pool: - results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)]) - for i, result in enumerate(results): - self.eye_stats.append(result) - else: - for i in range(self.channels): - self.eye_stats.append(self.analyse_single(self.raw_data[i], i)) + if not self.analysed: + update_save = True + self.eye_stats = [] + if self.multi_threaded: + with multiprocessing.Pool() as pool: + results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)]) + for i, result in enumerate(results): + self.eye_stats.append(result) + else: + for i in range(self.n_channels): + self.eye_stats.append(self.analyse_single(self.raw_data[i], i)) + self.analysed = True + + if mode == "save" or (mode == "default" and update_save): + self.save_data(file_or_dir) @staticmethod def approximate_levels(data, levels): @@ -443,7 +643,7 @@ class eye_diagram: if __name__ == "__main__": - length = int(2**14) + length = int(2**16) # data = generate_sample_data(length, noise=1) # data1 = generate_sample_data(length, noise=0.01) # data2 = generate_sample_data(length, noise=0.01, skew=1.2) @@ -451,13 +651,13 @@ if __name__ == "__main__": # data = np.stack([data, data1, data2, data3]) - data = generate_sample_data(length, noise=0.005) - eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256) - eye.analyse() - attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area") - for i, channel in enumerate(eye.eye_stats): - print(f"Channel {i}") - print_data = {attr: channel[attr] for attr in attrs} - print(print_data) + data = generate_sample_data(length, noise=0.0000) + eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200) + eye.plot(mode="nosave", stats=False) + # attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area") + # for i, channel in enumerate(eye.eye_stats): + # print(f"Channel {i}") + # print_data = {attr: channel[attr] for attr in attrs} + # print(print_data) - eye.plot() + # eye.plot() diff --git a/src/single-core-regen/util/mpl.py b/src/single-core-regen/util/mpl.py new file mode 100644 index 0000000..04a13d0 --- /dev/null +++ b/src/single-core-regen/util/mpl.py @@ -0,0 +1,122 @@ +# Copyright (c) 2015, Warren Weckesser. All rights reserved. +# This software is licensed according to the "BSD 2-clause" license. + +# modified by Joseph Hopfmüller in 2025, +# for integration into optical regeneration analysis scripts + +from pathlib import Path +from matplotlib.colors import LinearSegmentedColormap +import matplotlib.colors as colors +import numpy as _np +from .core import grid_count as _grid_count +import matplotlib.pyplot as _plt +import numpy as np +from scipy.ndimage import gaussian_filter + + +# from ._common import _common_doc + + +__all__ = ["eyediagram"] # , 'eyediagram_lines'] + + +# def eyediagram_lines(y, window_size, offset=0, **plotkwargs): +# """ +# Plot an eye diagram using matplotlib by repeatedly calling the `plot` +# function. +# + +# """ +# start = offset +# while start < len(y): +# end = start + window_size +# if end > len(y): +# end = len(y) +# yy = y[start:end+1] +# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs) +# start = end + +# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("", +# _common_doc) + + +eyemap = LinearSegmentedColormap.from_list( + "eyemap", + [ + (0, "#0000FF00"), + (0.1, "blue"), + (0.2, "cyan"), + (0.5, "green"), + (0.8, "yellow"), + (0.9, "red"), + (1, "magenta"), + ], +) + + +def eyediagram( + y, + window_size, + offset=0, + colorbar=False, + show=False, + save_im=False, + overwrite=False, + blur: int | bool = True, + save_path="out.png", + bounds=None, + **imshowkwargs, +): + """ + Plot an eye diagram using matplotlib by creating an image and calling + the `imshow` function. + + """ + if bounds is None: + ymax = y.max() + ymin = y.min() + yamp = ymax - ymin + ymin = ymin - 0.05 * yamp + ymax = ymax + 0.05 * yamp + ymin = np.floor(ymin * 10) / 10 + ymax = np.ceil(ymax * 10) / 10 + bounds = (ymin, ymax) + counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur)) + counts = counts.astype(_np.float32) + origin = imshowkwargs.pop("origin", "lower") + cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap) + vmin = imshowkwargs.pop("vmin", 1) + vmax = imshowkwargs.pop("vmax", None) + cmap.set_under("white", alpha=0) + + if show: + _plt.imshow( + counts.T[::-1, :], + extent=[0, 2, *bounds], + origin=origin, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **imshowkwargs, + ) + _plt.grid() + if colorbar: + _plt.colorbar() + + if Path(save_path).is_file() and not overwrite: + save_im = False + if save_im: + from PIL import Image + arr = counts.T[::-1, :] + if origin == "lower": + arr = arr[::-1] + arr = (arr-arr.min())/(arr.max()-arr.min()) + image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8)) + image.save(save_path) + # print("-") + + if show: + _plt.show() + + +# eyediagram.__doc__ = eyediagram.__doc__.replace("", _common_doc)