""" tests a given model for tolerance against variations in - fiber length - baudrate - OSNR CD, PMD, baudrate need different datasets, osnr is modeled as awgn added to the data before feeding into the model """ from datetime import datetime from typing import Literal from matplotlib import pyplot as plt import numpy as np from pathlib import Path import h5py import torch import util from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings from hypertraining import models from signal_gen.generate_signal import single_run, get_config import json class NestedParameterIterator: def __init__(self, parameters): """ parameters: dict with key and value """ # self.parameters = parameters self.names = [] self.ranges = [] self.configs = [] for k, v in parameters.items(): self.names.append(k) self.ranges.append(v["range"]) self.configs.append(v["config"]) self.n_parameters = len(self.ranges) self.idx = 0 self.range_idx = [0] * self.n_parameters self.range_len = [len(r) for r in self.ranges] self.length = int(np.prod(self.range_len)) self.out = [] for i in range(self.length): self.out.append([]) for j in range(self.n_parameters): element = {self.names[j]: {"value": self.ranges[j][self.range_idx[j]], "config": self.configs[j]}} self.out[i].append(element) self.range_idx[-1] += 1 # update range_idx back to front for j in range(self.n_parameters - 1, -1, -1): if self.range_idx[j] == self.range_len[j]: self.range_idx[j] = 0 self.range_idx[j - 1] += 1 ... def __next__(self): if self.idx == self.length: raise StopIteration self.idx += 1 return self.out[self.idx - 1] def __iter__(self): return self class model_runner: def __init__( self, # length_range: tuple[int | float] = (50e3, 50e3), # length_steps: int = 1, # length_log: bool = False, # baudrate_range: tuple[int | float] = (10e9, 10e9), # baudrate_steps: int = 1, # baudrate_log: bool = False, # osnr_range: tuple[int | float] = (16, 16), # osnr_steps: int = 1, # osnr_log: bool = False, # dataset_dir: str = "data", # dataset_datetime_glob: str = "*", results_dir: str = "tolerance_results/datasets", # model_dir: str = ".models", config: str = "signal_generation.ini", config_dir: str = None, debug: bool = False, ): """ length_range: lower and upper limit of length, in meters length_step: step size of length, in meters baudrate_range: lower and upper limit of baudrate, in Bd baudrate_step: step size of baudrate, in Bd osnr_range: lower and upper limit of osnr, in dB osnr_step: step size of osnr, in dB dataset_dir: directory containing datasets dataset_datetime_glob: datetime glob pattern for dataset files results_dir: directory to save results model_dir: directory containing models """ self.debug = debug self.parameters = {} self.iter = None # self.update_length_range(length_range, length_steps, length_log) # self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log) # self.update_osnr_range(osnr_range, osnr_steps, osnr_log) # self.data_dir = Path(dataset_dir) # self.data_datetime_glob = dataset_datetime_glob self.results_dir = Path(results_dir) # self.model_dir = Path(model_dir) config_dir = config_dir or Path(__file__).parent self.config = config_dir / config torch.serialization.add_safe_globals([ *util.complexNN.__all__, GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings, models.regenerator, torch.nn.utils.parametrizations.orthogonal, ]) self.load_model() self.datasets = [] # def register_parameter(self, name, config): # self.parameters.append({"name": 
name, "config": config}) def load_results_from_file(self, path): data, meta = self.load_from_file(path) self.results = [d.decode() for d in data] self.parameters = meta["parameters"] ... def load_datasets_from_file(self, path): data, meta = self.load_from_file(path) self.datasets = [d.decode() for d in data] self.parameters = meta["parameters"] ... def update_parameter_range(self, name, config, range, steps, log): self.parameters[name] = {"config": config, "range": self.update_range(*range, steps, log)} def generate_iterations(self): if len(self.parameters) == 0: raise ValueError("No parameters registered") self.iter = NestedParameterIterator(self.parameters) def generate_datasets(self): # get base config config = get_config(self.config) if self.iter is None: self.generate_iterations() for params in self.iter: current_settings = [] # params is a list of dictionaries with keys "name", containing a dict with keys "value", "config" for param in params: for name, settings in param.items(): current_settings.append({name: settings["value"]}) self.nested_set(config, settings["config"], settings["value"]) settings_strs = [] for setting in current_settings: name = list(setting)[0] settings_strs.append(f"{name}: {float(setting[name]):.2e}") settings_str = ", ".join(settings_strs) print(f"Generating dataset for [{settings_str}]") # TODO look for existing datasets _, _, path = single_run(config) self.datasets.append(str(path)) datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back") metadata = {"parameters": self.parameters} data = np.array(self.datasets, dtype="S") self.save_to_file(datasets_list_path, data, **metadata) @staticmethod def nested_set(dic, keys, value): for key in keys[:-1]: dic = dic.setdefault(key, {}) dic[keys[-1]] = value ## Dataset and model loading # def find_datasets(self, data_dir=None, data_datetime_glob=None): # # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini # data_dir = data_dir or self.data_dir # data_datetime_glob = data_datetime_glob or self.data_datetime_glob # self.datasets = {} # data_dir = Path(data_dir) # for length in self.lengths: # for baudrate in self.baudrates: # # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini" # dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini" # datasets = [f for f in data_dir.glob(dataset_glob)] # if len(datasets) == 0: # continue # self.datasets[length] = {} # if len(datasets) > 1: # print( # f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset." 
    #                 )
    #             # get newest file from creation date
    #             datasets.sort(key=lambda x: x.stat().st_ctime)
    #             self.datasets[length][baudrate] = str(datasets[-1])

    def load_dataset(self, dataset_path):
        if self.checkpoint_dict is None:
            raise ValueError("Model must be loaded before dataset")
        if self.dataset_path == dataset_path:
            return  # dataset already loaded
        self.dataset_path = dataset_path

        # rebuild the dataset with the data settings the model was trained with
        symbols = self.checkpoint_dict["settings"]["data_settings"].symbols
        data_size = self.checkpoint_dict["settings"]["data_settings"].output_size
        dtype = getattr(torch, self.checkpoint_dict["settings"]["data_settings"].dtype)
        drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first
        randomise_polarisations = self.checkpoint_dict["settings"]["data_settings"].randomise_polarisations
        polarisations = self.checkpoint_dict["settings"]["data_settings"].polarisations
        num_symbols = None
        if self.debug:
            num_symbols = 1000
        config_path = Path(dataset_path)
        dataset = util.datasets.FiberRegenerationDataset(
            file_path=config_path,
            symbols=symbols,
            output_dim=data_size,
            drop_first=drop_first,
            dtype=dtype,
            real=not dtype.is_complex,
            randomise_polarisations=randomise_polarisations,
            polarisations=polarisations,
            num_symbols=num_symbols,
            # device="cuda" if torch.cuda.is_available() else "cpu",
        )
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=2**14, pin_memory=True, num_workers=24, prefetch_factor=8, shuffle=False
        )
        return self.dataloader.dataset.orig_symbols

    # run model
    # return results as array: [fiber_in, fiber_out, fiber_out_noisy, regen_out]
    def load_model(self, model_path: str | None = None):
        if model_path is None:
            self.model = None
            self.model_path = None
            self.checkpoint_dict = None
            self.dataset_path = None
            return
        path = Path(model_path)
        if path == self.model_path:
            return  # model already loaded
        self.model_path = path
        self.dataset_path = None  # reset dataset path, as the data shape depends on the model
        self.checkpoint_dict = torch.load(path, weights_only=True)
        dims = self.checkpoint_dict["model_kwargs"].pop("dims")
        self.model = models.regenerator(*dims, **self.checkpoint_dict["model_kwargs"])
        self.model.load_state_dict(self.checkpoint_dict["model_state_dict"])

    ## Model evaluation
    def run_model_evaluation(self, model_path: str, datasets: str | None = None):
        self.load_model(model_path)
        # iterate over datasets:
        # load dataset, run model, save results to file
        self.results = []
        if datasets is not None:
            self.load_datasets_from_file(datasets)
        n_datasets = len(self.datasets)
        for i, dataset_path in enumerate(self.datasets):
            conf = get_config(dataset_path)
            mpath = Path(model_path)
            model_base = mpath.stem
            print(f"({i + 1}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}")
            results_path = self.build_path(
                dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base
            )
            orig_symbols = self.load_dataset(dataset_path)
            data, loss = self.run_model()
            metadata = {
                "model_path": model_path,
                "dataset_path": dataset_path,
                "loss": loss,
                "sps": conf["glova"]["sps"],
                "orig_symbols": orig_symbols,
                # "config": conf,
                # "checkpoint_dict": self.checkpoint_dict,
                # "nos": self.dataloader.dataset.num_symbols,
            }
            self.save_to_file(results_path, data, **metadata)
            self.results.append(str(results_path))
        results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back")
        metadata = {"parameters": self.parameters}
        data = np.array(self.results, dtype="S")
        self.save_to_file(results_list_path, data, **metadata)

    def run_model(self):
        loss = 0
        datas = []
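        # Inference runs batch-wise with gradients disabled: the stacked inputs and the
        # negated mean polarisation angle are moved to the GPU when one is available,
        # the per-batch complex MSE loss is accumulated, and only the centre sample of
        # each regenerated window is kept for the eye-diagram / SER evaluation later on.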
        self.model.eval()
        model = self.model.to("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            for batch in self.dataloader:
                x = batch["x_stacked"]
                y = batch["y_stacked"]
                fiber_in = batch["plot_target"]
                # fiber_out = batch["plot_clean"]
                fiber_out = batch["plot_data"]
                timestamp = batch["timestamp"]
                angle = batch["mean_angle"]
                x = x.to("cuda" if torch.cuda.is_available() else "cpu")
                angle = angle.to("cuda" if torch.cuda.is_available() else "cpu")
                regen = model(x, -angle)
                regen = regen.to("cpu")
                loss += util.complexNN.complex_mse_loss(regen, y, power=True).item()
                # shape: [batch_size, 4]
                plot_regen = regen[:, :2]
                plot_regen = plot_regen.view(plot_regen.shape[0], -1, 2)
                plot_regen = plot_regen[:, plot_regen.shape[1] // 2, :]
                data_out = torch.cat(
                    (
                        fiber_in,
                        fiber_out,
                        # fiber_out_noisy,
                        plot_regen,
                        timestamp.view(-1, 1),
                    ),
                    dim=1,
                )
                datas.append(data_out)
        data_out = torch.cat(datas, dim=0).numpy()
        return data_out, loss

    ## File I/O
    @staticmethod
    def save_to_file(path: str | Path, data: np.ndarray, **metadata):
        path = Path(path)
        # create directory if it doesn't exist
        path.parent.mkdir(parents=True, exist_ok=True)
        with h5py.File(path, "w") as outfile:
            outfile.create_dataset("data", data=data)
            for key, value in metadata.items():
                if isinstance(value, dict):
                    value = json.dumps(model_runner.convert_arrays(value))
                outfile.attrs[key] = value

    @staticmethod
    def convert_arrays(dict_in):
        """convert ndarrays in a (nested) dict to lists"""
        dict_out = {}
        for key, value in dict_in.items():
            if isinstance(value, dict):
                dict_out[key] = model_runner.convert_arrays(value)
            elif isinstance(value, np.ndarray):
                dict_out[key] = value.tolist()
            else:
                dict_out[key] = value
        return dict_out

    @staticmethod
    def load_from_file(path: str):
        with h5py.File(path, "r") as infile:
            data = infile["data"][:]
            metadata = {}
            for key in infile.attrs.keys():
                if isinstance(infile.attrs[key], (str, bytes, bytearray)):
                    try:
                        metadata[key] = json.loads(infile.attrs[key])
                    except json.JSONDecodeError:
                        metadata[key] = infile.attrs[key]
                else:
                    metadata[key] = infile.attrs[key]
        return data, metadata

    ## Utility functions
    @staticmethod
    def logrange(start, stop, num, endpoint=False):
        lower, upper = np.log10((start, stop))
        return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10)

    @staticmethod
    def build_path(
        *elements, parent_dir: str | Path | None = None, filetype="h5", timestamp: Literal["no", "front", "back"] = "no"
    ):
        suffix = f".{filetype}" if not filetype.startswith(".") else filetype
        if timestamp != "no":
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            if timestamp == "front":
                elements = (ts, *elements)
            elif timestamp == "back":
                elements = (*elements, ts)
        path = "_".join(elements)
        path += suffix
        if parent_dir is not None:
            path = Path(parent_dir) / path
        return path

    @staticmethod
    def update_range(start, stop, n_steps, log):
        if log:
            values = model_runner.logrange(start, stop, n_steps, endpoint=True)
        else:
            values = np.linspace(start, stop, n_steps, endpoint=True)
        return values


class model_evaluation_result:
    def __init__(
        self,
        *,
        length=None,
        baudrate=None,
        osnr=None,
        model_path=None,
        dataset_path=None,
        loss=None,
        sps=None,
        **kwargs,
    ):
        self.length = length
        self.baudrate = baudrate
        self.osnr = osnr
        self.model_path = model_path
        self.dataset_path = dataset_path
        self.loss = loss
        self.sps = sps
        self.sers = None
        self.bers = None
        self.eye_stats = None


class evaluator:
    def __init__(self, datasets: list[str]):
        """
        datasets: iterable of result-file paths (as produced by model_runner.run_model_evaluation)
        """
        self.datasets = datasets
        self.results = []

    def evaluate_datasets(self, plot=False):
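        """Load each result file, rebuild the eye diagrams and compute/print SER and BER
        per polarisation; with plot=True the eye-diagram figures are shown at the end."""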
        ## iterate over datasets
        # load dataset
        for dataset in self.datasets:
            model, dataset_name = dataset.split("/")[-2:]
            print(f"\nEvaluating model {model} with dataset {dataset_name}")
            data, metadata = model_runner.load_from_file(dataset)
            result = model_evaluation_result(**metadata)
            data = self.prepare_data(data, sps=metadata["sps"])
            try:
                sym_x, sym_y = metadata["orig_symbols"]
            except (TypeError, KeyError, ValueError):
                sym_x, sym_y = None, None
            self.evaluate_eye(data, result, title=dataset.split("/")[-1], plot=False)
            self.evaluate_ser_ber(data, result, sym_x, sym_y)
            print("BER:")
            self.print_dict(result.bers["regen"])
            print()
            print("SER:")
            self.print_dict(result.sers["regen"])
            print()
            self.results.append(result)
        if plot:
            plt.show()

    def evaluate_eye(self, data, result, title=None, plot=False):
        eye = util.eye_diagram.eye_diagram(
            data,
            channel_names=[
                "fiber_in_x",
                "fiber_in_y",
                # "fiber_out_x",
                # "fiber_out_y",
                "fiber_out_x",
                "fiber_out_y",
                "regen_x",
                "regen_y",
            ],
        )
        eye.analyse()
        eye.plot(title=title or "Eye diagram", show=plot)
        result.eye_stats = eye.eye_stats
        return eye.eye_stats

    def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None):
        if result.eye_stats is None:
            self.evaluate_eye(data, result)
        symbols = []
        sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
        bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
        for channel_data, stats in zip(data, result.eye_stats):
            timestamps = channel_data[0]
            dat = channel_data[1]
            channel_name = stats["channel_name"]
            if stats["success"]:
                thresholds = stats["thresholds"]
                time_midpoint = stats["time_midpoint"]
            else:
                # fall back to the thresholds of the corresponding fiber_in channel
                if channel_name.endswith("x"):
                    thresholds = result.eye_stats[0]["thresholds"]
                    time_midpoint = result.eye_stats[0]["time_midpoint"]
                elif channel_name.endswith("y"):
                    thresholds = result.eye_stats[1]["thresholds"]
                    time_midpoint = result.eye_stats[1]["time_midpoint"]
                else:
                    levels = np.linspace(np.min(dat), np.max(dat), 4)
                    thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels)
                    time_midpoint = 1.0
            # time_offset = time_midpoint - 0.5
            # # time_offset = 0
            # index_offset = np.argmin(np.abs((timestamps - time_offset) % 1.0))
            nos = len(timestamps) // result.sps
            # idx = np.arange(index_offset, len(timestamps), result.sps).astype(int)
            # if time_offset < 0:
            #     idx = np.insert(idx, 0, 0)
            idx = list(range(0, len(timestamps), result.sps))
            idx = idx[:nos]
            data_sampled = dat[idx]
            detected_symbols = self.detect_symbols(data_sampled, thresholds)
            symbols.append({"channel_name": channel_name, "symbols": detected_symbols})
        # use the transmitted symbols as ground truth if available, otherwise the symbols detected on fiber_in
        symbols_x_gt = sym_x if sym_x is not None else symbols[0]["symbols"]
        symbols_y_gt = sym_y if sym_y is not None else symbols[1]["symbols"]
        symbols_x_fiber_out = symbols[2]["symbols"]
        symbols_y_fiber_out = symbols[3]["symbols"]
        symbols_x_regen = symbols[4]["symbols"]
        symbols_y_regen = symbols[5]["symbols"]
        sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out)
        sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out)
        sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen)
        sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen)
        result.sers = sers
        result.bers = bers

    @staticmethod
    def calculate_ser_ber(symbols_gt, symbols):
        # levels = 4
        # symbol difference -> bit error count
        # |rx - tx| = 0 -> 0
        # |rx - tx| = 1 -> 1
        # |rx - tx| = 2 -> 2
        # |rx - tx| = 3 -> 1
        # assuming gray coding -> 0: 00, 1: 01, 2: 11, 3: 10
        bec_map = {0: 0, 1: 1, 2: 2, 3: 1, np.nan: 2}
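        # SER counts whole-symbol mismatches; BER maps each symbol distance to Gray-coded
        # bit errors via bec_map and normalises by 2 bits per PAM4 symbol.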
        ser = {}
        ber = {}
        ser["n_symbols"] = len(symbols_gt)
        ser["n_errors"] = np.sum(symbols != symbols_gt)
        ser["total"] = float(ser["n_errors"] / ser["n_symbols"])
        bec = np.vectorize(bec_map.get)(np.abs(symbols - symbols_gt))
        bit_errors = np.sum(bec)
        ber["n_bits"] = len(symbols_gt) * 2
        ber["n_errors"] = bit_errors
        ber["total"] = float(ber["n_errors"] / ber["n_bits"])
        return ser, ber

    @staticmethod
    def print_dict(d: dict, indent=2, logarithmic=False, level=0):
        for key, value in d.items():
            if isinstance(value, dict):
                print(f"{' ' * indent * level}{key}:")
                evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1)
            else:
                if isinstance(value, float):
                    if logarithmic:
                        if value == 0:
                            value = -np.inf
                        else:
                            value = np.log10(value)
                    print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="")
                else:
                    print(f"{' ' * indent * level}{key}: {value}\t", end="")
        print()

    @staticmethod
    def detect_symbols(samples, thresholds=None):
        thresholds = (1 / 6, 3 / 6, 5 / 6) if thresholds is None else thresholds
        thresholds = (-np.inf, *thresholds, np.inf)
        symbols = np.digitize(samples, thresholds) - 1
        return symbols

    @staticmethod
    def prepare_data(data, sps=None):
        data = data.transpose(1, 0)
        timestamps = data[-1].real
        data = data[:-1]
        if sps is not None:
            timestamps /= sps
        # data = np.stack(
        #     (
        #         *data[0:2],  # fiber_in_x, fiber_in_y
        #         # *data_[2:4],  # fiber_out_x, fiber_out_y
        #         *data[4:6],  # fiber_out_noisy_x, fiber_out_noisy_y
        #         *data[6:8],  # regen_out_x, regen_out_y
        #     ),
        #     axis=0,
        # )
        data_eye = []
        for channel_values in data:
            channel_values = np.square(np.abs(channel_values))
            data_eye.append(np.stack((timestamps, channel_values), axis=0))
        data_eye = np.stack(data_eye, axis=0)
        return data_eye


def generate_data(parameters, runner=None):
    runner = runner or model_runner()
    for param in parameters:
        runner.update_parameter_range(*param)
    runner.generate_iterations()
    print(f"{runner.iter.length} parameter combinations")
    runner.generate_datasets()
    return runner


if __name__ == "__main__":
    model_path = ".models/best_20250110_191149.tar"  # D 17, OSNR 100, delta_beta 0.14, baud 10e9

    parameters = (
        # name, config keys, (min, max), n_steps, log
        # ("D", ("fiber", "d"), (28, 30), 3, False),
        # ("S", ("fiber", "s"), (0, 0.058), 2, False),
        ("OSNR", ("signal", "osnr"), (20, 40), 5, False),
        # ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False),
        # ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True),
    )

    datasets = None
    results = None
    # datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5"
    results = "tolerance_results/datasets/results_list_20250110_232639.h5"

    runner = model_runner()
    # generate_data(parameters, runner)

    if results is None:
        if datasets is None:
            generate_data(parameters, runner)
        else:
            runner.load_datasets_from_file(datasets)
            print(f"{len(runner.datasets)} loaded")
        runner.run_model_evaluation(model_path)
    else:
        runner.load_results_from_file(results)

    # print(runner.parameters)
    # print(runner.results)
    ev = evaluator(runner.results)
    ev.evaluate_datasets(plot=True)