model robustness testing
723
src/single-core-regen/tolerance_testing.py
Normal file
@@ -0,0 +1,723 @@
"""
Tests a given model for tolerance against variations in

- fiber length
- baud rate
- OSNR

CD, PMD and baud-rate variations need separate datasets; OSNR is modeled as AWGN
added to the data before it is fed into the model.
"""

from datetime import datetime
from typing import Literal

from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path

import h5py
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models

from signal_gen.generate_signal import single_run, get_config

import json


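# The OSNR sweep relies on noise being added to the received samples before they are
# fed to the model (see the module docstring). A minimal sketch of that idea, assuming
# OSNR is treated as a plain SNR in dB over the sampled signal (ignoring the
# reference-bandwidth normalisation a strict OSNR definition would use); this helper is
# illustrative only and is not called anywhere in this module:
def _add_awgn_for_osnr_sketch(signal: np.ndarray, osnr_db: float) -> np.ndarray:
    signal_power = np.mean(np.abs(signal) ** 2)
    noise_power = signal_power / (10 ** (osnr_db / 10))
    # complex AWGN with the requested total noise power, split between I and Q
    noise = np.sqrt(noise_power / 2) * (
        np.random.randn(*signal.shape) + 1j * np.random.randn(*signal.shape)
    )
    return signal + noise

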
class NestedParameterIterator:
    def __init__(self, parameters):
        """
        parameters: dict mapping <param_name> to a dict with keys "config" and "range"
        """
        # self.parameters = parameters
        self.names = []
        self.ranges = []
        self.configs = []
        for k, v in parameters.items():
            self.names.append(k)
            self.ranges.append(v["range"])
            self.configs.append(v["config"])
        self.n_parameters = len(self.ranges)

        self.idx = 0

        self.range_idx = [0] * self.n_parameters
        self.range_len = [len(r) for r in self.ranges]
        self.length = int(np.prod(self.range_len))
        self.out = []

        # pre-compute all parameter combinations; the last parameter varies fastest
        for i in range(self.length):
            self.out.append([])
            for j in range(self.n_parameters):
                element = {self.names[j]: {"value": self.ranges[j][self.range_idx[j]], "config": self.configs[j]}}
                self.out[i].append(element)
            self.range_idx[-1] += 1

            # update range_idx back to front (carry overflow to the next slower-varying parameter)
            for j in range(self.n_parameters - 1, -1, -1):
                if self.range_idx[j] == self.range_len[j]:
                    self.range_idx[j] = 0
                    self.range_idx[j - 1] += 1

    def __next__(self):
        if self.idx == self.length:
            raise StopIteration
        self.idx += 1
        return self.out[self.idx - 1]

    def __iter__(self):
        return self


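# Illustrative usage of NestedParameterIterator (not used elsewhere in this file; the
# parameter names and values mirror the spec used in __main__ below):
def _nested_parameter_iterator_example():
    # The last registered parameter varies fastest; each yielded item is a list of
    # single-entry dicts like {"OSNR": {"value": 20, "config": ("signal", "osnr")}}.
    it = NestedParameterIterator(
        {
            "OSNR": {"config": ("signal", "osnr"), "range": [20, 30, 40]},
            "Baud": {"config": ("glova", "symbolrate"), "range": [10e9, 28e9]},
        }
    )
    return list(it)  # 6 combinations (3 OSNR values x 2 baud rates)

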
class model_runner:
    def __init__(
        self,
        # length_range: tuple[int | float] = (50e3, 50e3),
        # length_steps: int = 1,
        # length_log: bool = False,
        # baudrate_range: tuple[int | float] = (10e9, 10e9),
        # baudrate_steps: int = 1,
        # baudrate_log: bool = False,
        # osnr_range: tuple[int | float] = (16, 16),
        # osnr_steps: int = 1,
        # osnr_log: bool = False,
        # dataset_dir: str = "data",
        # dataset_datetime_glob: str = "*",
        results_dir: str = "tolerance_results/datasets",
        # model_dir: str = ".models",
        config: str = "signal_generation.ini",
        config_dir: str | None = None,
        debug: bool = False,
    ):
        """
        results_dir: directory to save results
        config: name of the base signal-generation config file
        config_dir: directory containing the config file (defaults to this file's directory)
        debug: if True, load only a small number of symbols per dataset

        Parameter ranges (length, baud rate, OSNR, ...) are registered via
        update_parameter_range() instead of constructor arguments.
        """
        self.debug = debug

        self.parameters = {}
        self.iter = None

        # self.update_length_range(length_range, length_steps, length_log)
        # self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log)
        # self.update_osnr_range(osnr_range, osnr_steps, osnr_log)

        # self.data_dir = Path(dataset_dir)
        # self.data_datetime_glob = dataset_datetime_glob
        self.results_dir = Path(results_dir)
        # self.model_dir = Path(model_dir)

        config_dir = Path(config_dir) if config_dir is not None else Path(__file__).parent
        self.config = config_dir / config

        # allow-list the custom classes stored in the checkpoints for torch.load(weights_only=True)
        torch.serialization.add_safe_globals([
            *util.complexNN.__all__,
            GlobalSettings,
            DataSettings,
            ModelSettings,
            OptimizerSettings,
            PytorchSettings,
            models.regenerator,
            torch.nn.utils.parametrizations.orthogonal,
        ])

        self.load_model()

        self.datasets = []

    # def register_parameter(self, name, config):
    #     self.parameters.append({"name": name, "config": config})

    def load_results_from_file(self, path):
        data, meta = self.load_from_file(path)
        self.results = [d.decode() for d in data]
        self.parameters = meta["parameters"]

    def load_datasets_from_file(self, path):
        data, meta = self.load_from_file(path)
        self.datasets = [d.decode() for d in data]
        self.parameters = meta["parameters"]

    def update_parameter_range(self, name, config, limits, steps, log):
        self.parameters[name] = {"config": config, "range": self.update_range(*limits, steps, log)}

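    # Example (mirrors the spec used in __main__ below): register a 5-point linear OSNR sweep
    #     runner.update_parameter_range("OSNR", ("signal", "osnr"), (20, 40), 5, False)
    # after which generate_iterations()/generate_datasets() produce one dataset per value.
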
    def generate_iterations(self):
        if len(self.parameters) == 0:
            raise ValueError("No parameters registered")
        self.iter = NestedParameterIterator(self.parameters)

    def generate_datasets(self):
        # get base config
        config = get_config(self.config)

        if self.iter is None:
            self.generate_iterations()

        for params in self.iter:
            current_settings = []
            # params is a list of single-entry dicts: {<param_name>: {"value": ..., "config": ...}}
            for param in params:
                for name, settings in param.items():
                    current_settings.append({name: settings["value"]})
                    self.nested_set(config, settings["config"], settings["value"])
            settings_strs = []
            for setting in current_settings:
                name = list(setting)[0]
                settings_strs.append(f"{name}: {float(setting[name]):.2e}")
            settings_str = ", ".join(settings_strs)
            print(f"Generating dataset for [{settings_str}]")
            # TODO look for existing datasets
            _, _, path = single_run(config)
            self.datasets.append(str(path))

        datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back")
        metadata = {"parameters": self.parameters}
        data = np.array(self.datasets, dtype="S")
        self.save_to_file(datasets_list_path, data, **metadata)

    @staticmethod
    def nested_set(dic, keys, value):
        for key in keys[:-1]:
            dic = dic.setdefault(key, {})
        dic[keys[-1]] = value

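    # Example: nested_set(config, ("signal", "osnr"), 30) is equivalent to
    #     config.setdefault("signal", {})["osnr"] = 30
    # i.e. the "config" tuple registered for a parameter addresses a nested key in the
    # signal-generation config.
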
    ## Dataset and model loading
    # def find_datasets(self, data_dir=None, data_datetime_glob=None):
    #     # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini
    #     data_dir = data_dir or self.data_dir
    #     data_datetime_glob = data_datetime_glob or self.data_datetime_glob
    #     self.datasets = {}
    #     data_dir = Path(data_dir)
    #     for length in self.lengths:
    #         for baudrate in self.baudrates:
    #             # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini"
    #             dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini"
    #             datasets = [f for f in data_dir.glob(dataset_glob)]
    #             if len(datasets) == 0:
    #                 continue
    #             self.datasets[length] = {}
    #             if len(datasets) > 1:
    #                 print(
    #                     f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset."
    #                 )
    #             # get newest file from creation date
    #             datasets.sort(key=lambda x: x.stat().st_ctime)
    #             self.datasets[length][baudrate] = str(datasets[-1])

    def load_dataset(self, dataset_path):
        if self.checkpoint_dict is None:
            raise ValueError("Model must be loaded before dataset")

        if self.dataset_path == dataset_path:
            # this dataset is already loaded for the current model
            return self.dataloader.dataset.orig_symbols
        self.dataset_path = dataset_path

        data_settings = self.checkpoint_dict["settings"]["data_settings"]
        symbols = data_settings.symbols
        data_size = data_settings.output_size
        dtype = getattr(torch, data_settings.dtype)
        drop_first = data_settings.drop_first
        randomise_polarisations = data_settings.randomise_polarisations
        polarisations = data_settings.polarisations
        num_symbols = None
        if self.debug:
            num_symbols = 1000

        config_path = Path(dataset_path)

        dataset = util.datasets.FiberRegenerationDataset(
            file_path=config_path,
            symbols=symbols,
            output_dim=data_size,
            drop_first=drop_first,
            dtype=dtype,
            real=not dtype.is_complex,
            randomise_polarisations=randomise_polarisations,
            polarisations=polarisations,
            num_symbols=num_symbols,
            # device="cuda" if torch.cuda.is_available() else "cpu",
        )

        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=2**14, pin_memory=True, num_workers=24, prefetch_factor=8, shuffle=False
        )

        return self.dataloader.dataset.orig_symbols
        # run model
        # return results as array: [fiber_in, fiber_out, fiber_out_noisy, regen_out]

    def load_model(self, model_path: str | None = None):
        if model_path is None:
            self.model = None
            self.model_path = None
            self.checkpoint_dict = None
            return

        path = Path(model_path)

        if path == self.model_path:
            # this checkpoint is already loaded
            return
        self.model_path = path

        self.dataset_path = None  # reset dataset path, as the shape depends on the model

        self.checkpoint_dict = torch.load(path, weights_only=True)
        dims = self.checkpoint_dict["model_kwargs"].pop("dims")
        self.model = models.regenerator(*dims, **self.checkpoint_dict["model_kwargs"])
        self.model.load_state_dict(self.checkpoint_dict["model_state_dict"])

    ## Model evaluation
    def run_model_evaluation(self, model_path: str, datasets: str | None = None):
        self.load_model(model_path)
        # iterate over datasets and osnr values:
        # load dataset, add noise, run model, return results
        # save results to file
        self.results = []

        if datasets is not None:
            self.load_datasets_from_file(datasets)

        n_datasets = len(self.datasets)
        for i, dataset_path in enumerate(self.datasets):
            conf = get_config(dataset_path)
            mpath = Path(model_path)
            model_base = mpath.stem
            print(f"({i + 1}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}")

            results_path = self.build_path(
                dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base
            )

            orig_symbols = self.load_dataset(dataset_path)

            data, loss = self.run_model()

            metadata = {
                "model_path": model_path,
                "dataset_path": dataset_path,
                "loss": loss,
                "sps": conf["glova"]["sps"],
                "orig_symbols": orig_symbols,
                # "config": conf,
                # "checkpoint_dict": self.checkpoint_dict,
                # "nos": self.dataloader.dataset.num_symbols,
            }

            self.save_to_file(results_path, data, **metadata)
            self.results.append(str(results_path))

        results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back")
        metadata = {"parameters": self.parameters}
        data = np.array(self.results, dtype="S")
        self.save_to_file(results_list_path, data, **metadata)

    def run_model(self):
        loss = 0
        datas = []

        device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model.eval()
        model = self.model.to(device)
        with torch.no_grad():
            for batch in self.dataloader:
                x = batch["x_stacked"]
                y = batch["y_stacked"]
                fiber_in = batch["plot_target"]
                # fiber_out = batch["plot_clean"]
                fiber_out = batch["plot_data"]
                timestamp = batch["timestamp"]
                angle = batch["mean_angle"]
                x = x.to(device)
                angle = angle.to(device)
                regen = model(x, -angle)
                regen = regen.to("cpu")
                loss += util.complexNN.complex_mse_loss(regen, y, power=True).item()
                # shape: [batch_size, 4]
                plot_regen = regen[:, :2]
                plot_regen = plot_regen.view(plot_regen.shape[0], -1, 2)
                plot_regen = plot_regen[:, plot_regen.shape[1] // 2, :]

                data_out = torch.cat(
                    (
                        fiber_in,
                        fiber_out,
                        # fiber_out_noisy,
                        plot_regen,
                        timestamp.view(-1, 1),
                    ),
                    dim=1,
                )
                datas.append(data_out)

        data_out = torch.cat(datas, dim=0).numpy()

        return data_out, loss

    ## File I/O
    @staticmethod
    def save_to_file(path: str | Path, data: np.ndarray, **metadata):
        path = Path(path)
        # create directory if it doesn't exist
        path.parent.mkdir(parents=True, exist_ok=True)

        with h5py.File(path, "w") as outfile:
            outfile.create_dataset("data", data=data)
            for key, value in metadata.items():
                if isinstance(value, dict):
                    # dicts are not valid HDF5 attributes; store them as JSON strings
                    value = json.dumps(model_runner.convert_arrays(value))
                outfile.attrs[key] = value

    @staticmethod
    def convert_arrays(dict_in):
        """
        convert ndarrays in a (nested) dict to lists
        """
        dict_out = {}
        for key, value in dict_in.items():
            if isinstance(value, dict):
                dict_out[key] = model_runner.convert_arrays(value)
            elif isinstance(value, np.ndarray):
                dict_out[key] = value.tolist()
            else:
                dict_out[key] = value
        return dict_out

    @staticmethod
    def load_from_file(path: str):
        with h5py.File(path, "r") as infile:
            data = infile["data"][:]
            metadata = {}
            for key in infile.attrs.keys():
                if isinstance(infile.attrs[key], (str, bytes, bytearray)):
                    # dict-valued metadata was stored as a JSON string; decode it back
                    try:
                        metadata[key] = json.loads(infile.attrs[key])
                    except json.JSONDecodeError:
                        metadata[key] = infile.attrs[key]
                else:
                    metadata[key] = infile.attrs[key]
        return data, metadata

    ## Utility functions
    @staticmethod
    def logrange(start, stop, num, endpoint=False):
        lower, upper = np.log10((start, stop))
        return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10)

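    # Example: logrange(10e9, 100e9, 3, endpoint=True) -> [1e10, ~3.16e10, 1e11]
    # (logarithmically spaced, as used for the baud-rate sweep in __main__ below).
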
    @staticmethod
    def build_path(
        *elements, parent_dir: str | Path | None = None, filetype="h5", timestamp: Literal["no", "front", "back"] = "no"
    ):
        suffix = f".{filetype}" if not filetype.startswith(".") else filetype
        if timestamp != "no":
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            if timestamp == "front":
                elements = (ts, *elements)
            elif timestamp == "back":
                elements = (*elements, ts)
        path = "_".join(elements)
        path += suffix

        if parent_dir is not None:
            path = Path(parent_dir) / path
        return path

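    # Example: build_path("results_list", parent_dir="tolerance_results/datasets", timestamp="back")
    #     -> Path("tolerance_results/datasets/results_list_<YYYYmmdd_HHMMSS>.h5")
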
    @staticmethod
    def update_range(start, stop, n_steps, log):
        # returns n_steps values from start to stop (inclusive), linearly or logarithmically spaced
        if log:
            values = model_runner.logrange(start, stop, n_steps, endpoint=True)
        else:
            values = np.linspace(start, stop, n_steps, endpoint=True)
        return values


class model_evaluation_result:
    def __init__(
        self,
        *,
        length=None,
        baudrate=None,
        osnr=None,
        model_path=None,
        dataset_path=None,
        loss=None,
        sps=None,
        **kwargs,
    ):
        self.length = length
        self.baudrate = baudrate
        self.osnr = osnr
        self.model_path = model_path
        self.dataset_path = dataset_path
        self.loss = loss
        self.sps = sps

        self.sers = None
        self.bers = None
        self.eye_stats = None


class evaluator:
    def __init__(self, datasets: list[str]):
        """
        datasets: iterable of result-file paths produced by model_runner.run_model_evaluation
        """
        self.datasets = datasets
        self.results = []

    def evaluate_datasets(self, plot=False):
        ## iterate over datasets
        # load dataset
        for dataset in self.datasets:
            model, dataset_name = dataset.split("/")[-2:]
            print(f"\nEvaluating model {model} with dataset {dataset_name}")
            data, metadata = model_runner.load_from_file(dataset)
            result = model_evaluation_result(**metadata)

            data = self.prepare_data(data, sps=metadata["sps"])

            try:
                sym_x, sym_y = metadata["orig_symbols"]
            except (TypeError, KeyError, ValueError):
                sym_x, sym_y = None, None

            self.evaluate_eye(data, result, title=dataset.split("/")[-1], plot=False)
            self.evaluate_ser_ber(data, result, sym_x, sym_y)
            print("BER:")
            self.print_dict(result.bers["regen"])
            print()
            print("SER:")
            self.print_dict(result.sers["regen"])
            print()

            self.results.append(result)
            if plot:
                plt.show()


    def evaluate_eye(self, data, result, title=None, plot=False):
        eye = util.eye_diagram.eye_diagram(
            data,
            channel_names=[
                "fiber_in_x",
                "fiber_in_y",
                # "fiber_out_x",
                # "fiber_out_y",
                "fiber_out_x",
                "fiber_out_y",
                "regen_x",
                "regen_y",
            ],
        )

        eye.analyse()
        eye.plot(title=title or "Eye diagram", show=plot)

        result.eye_stats = eye.eye_stats
        return eye.eye_stats

    def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None):
        if result.eye_stats is None:
            self.evaluate_eye(data, result)

        symbols = []
        sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
        bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}

        for channel_data, stats in zip(data, result.eye_stats):
            timestamps = channel_data[0]
            dat = channel_data[1]

            channel_name = stats["channel_name"]
            if stats["success"]:
                thresholds = stats["thresholds"]
                time_midpoint = stats["time_midpoint"]
            else:
                # eye analysis failed: fall back to the fiber_in thresholds of the same
                # polarisation, or to evenly spaced levels as a last resort
                if channel_name.endswith("x"):
                    thresholds = result.eye_stats[0]["thresholds"]
                    time_midpoint = result.eye_stats[0]["time_midpoint"]
                elif channel_name.endswith("y"):
                    thresholds = result.eye_stats[1]["thresholds"]
                    time_midpoint = result.eye_stats[1]["time_midpoint"]
                else:
                    levels = np.linspace(np.min(dat), np.max(dat), 4)
                    thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels)
                    time_midpoint = 1.0

            # time_offset = time_midpoint - 0.5
            # # time_offset = 0

            # index_offset = np.argmin(np.abs((timestamps - time_offset) % 1.0))

            nos = len(timestamps) // result.sps

            # idx = np.arange(index_offset, len(timestamps), result.sps).astype(int)

            # if time_offset < 0:
            #     idx = np.insert(idx, 0, 0)

            idx = list(range(0, len(timestamps), result.sps))

            idx = idx[:nos]

            data_sampled = dat[idx]
            detected_symbols = self.detect_symbols(data_sampled, thresholds)

            symbols.append({"channel_name": channel_name, "symbols": detected_symbols})

        # use the transmitted symbols as ground truth if available, otherwise the detected fiber_in symbols
        symbols_x_gt = symbols[0]["symbols"] if sym_x is None else sym_x
        symbols_y_gt = symbols[1]["symbols"] if sym_y is None else sym_y

        symbols_x_fiber_out = symbols[2]["symbols"]
        symbols_y_fiber_out = symbols[3]["symbols"]

        symbols_x_regen = symbols[4]["symbols"]
        symbols_y_regen = symbols[5]["symbols"]

        sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out)
        sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out)
        sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen)
        sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen)

        result.sers = sers
        result.bers = bers

    @staticmethod
    def calculate_ser_ber(symbols_gt, symbols):
        # levels = 4
        # symbol difference -> bit error count
        # |rx - tx| = 0 -> 0
        # |rx - tx| = 1 -> 1
        # |rx - tx| = 2 -> 2
        # |rx - tx| = 3 -> 1
        # assuming gray coding -> 0: 00, 1: 01, 2: 11, 3: 10
        bec_map = {0: 0, 1: 1, 2: 2, 3: 1, np.nan: 2}

        ser = {}
        ber = {}
        ser["n_symbols"] = len(symbols_gt)
        ser["n_errors"] = np.sum(symbols != symbols_gt)
        ser["total"] = float(ser["n_errors"] / ser["n_symbols"])

        bec = np.vectorize(bec_map.get)(np.abs(symbols - symbols_gt))
        bit_errors = np.sum(bec)

        ber["n_bits"] = len(symbols_gt) * 2  # 2 bits per PAM4 symbol
        ber["n_errors"] = bit_errors
        ber["total"] = float(ber["n_errors"] / ber["n_bits"])

        return ser, ber

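    # Worked example of the Gray-coded bit-error count above, with
    #     symbols_gt = [0, 1, 2, 3] and symbols = [0, 2, 2, 0]:
    # symbol errors: 2 of 4 -> SER 0.5
    # bit errors: |2-1| = 1 -> 1 bit, |0-3| = 3 -> 1 bit (00 vs 10), total 2 of 8 bits -> BER 0.25
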
    @staticmethod
    def print_dict(d: dict, indent=2, logarithmic=False, level=0):
        for key, value in d.items():
            if isinstance(value, dict):
                print(f"{' ' * indent * level}{key}:")
                evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1)
            else:
                if isinstance(value, float):
                    if logarithmic:
                        if value == 0:
                            value = -np.inf
                        else:
                            value = np.log10(value)
                    print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="")
                else:
                    print(f"{' ' * indent * level}{key}: {value}\t", end="")
        print()

    @staticmethod
    def detect_symbols(samples, thresholds=None):
        # map samples to PAM4 symbols 0..3 by comparing against the three decision thresholds
        thresholds = (1 / 6, 3 / 6, 5 / 6) if thresholds is None else thresholds
        thresholds = (-np.inf, *thresholds, np.inf)
        symbols = np.digitize(samples, thresholds) - 1
        return symbols

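    # Example: with the default thresholds (1/6, 3/6, 5/6),
    #     detect_symbols(np.array([0.05, 0.4, 0.6, 0.95])) -> array([0, 1, 2, 3])
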
    @staticmethod
    def prepare_data(data, sps=None):
        # data arrives as [samples, channels + 1]; the last column is the timestamp
        data = data.transpose(1, 0)
        timestamps = data[-1].real
        data = data[:-1]
        if sps is not None:
            timestamps /= sps

        # data = np.stack(
        #     (
        #         *data[0:2],  # fiber_in_x, fiber_in_y
        #         # *data_[2:4],  # fiber_out_x, fiber_out_y
        #         *data[4:6],  # fiber_out_noisy_x, fiber_out_noisy_y
        #         *data[6:8],  # regen_out_x, regen_out_y
        #     ),
        #     axis=0,
        # )

        # convert complex field amplitudes to power and pair each channel with the timestamps
        data_eye = []
        for channel_values in data:
            channel_values = np.square(np.abs(channel_values))
            data_eye.append(np.stack((timestamps, channel_values), axis=0))

        data_eye = np.stack(data_eye, axis=0)

        return data_eye


def generate_data(parameters, runner=None):
    runner = runner or model_runner()
    for param in parameters:
        runner.update_parameter_range(*param)
    runner.generate_iterations()
    print(f"{runner.iter.length} parameter combinations")
    runner.generate_datasets()

    return runner


if __name__ == "__main__":
    model_path = ".models/best_20250110_191149.tar"  # D 17, OSNR 100, delta_beta 0.14, baud 10e9

    parameters = (
        # name, config keys, (min, max), n_steps, log
        # ("D", ("fiber", "d"), (28, 30), 3, False),
        # ("S", ("fiber", "s"), (0, 0.058), 2, False),
        ("OSNR", ("signal", "osnr"), (20, 40), 5, False),
        # ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False),
        # ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True),
    )

    datasets = None
    results = None

    # datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5"
    results = "tolerance_results/datasets/results_list_20250110_232639.h5"

    runner = model_runner()
    # generate_data(parameters, runner)

    if results is None:
        if datasets is None:
            generate_data(parameters, runner)
        else:
            runner.load_datasets_from_file(datasets)
            print(f"{len(runner.datasets)} datasets loaded")

        runner.run_model_evaluation(model_path)
    else:
        runner.load_results_from_file(results)

    # print(runner.parameters)
    # print(runner.results)

    evaluation = evaluator(runner.results)
    evaluation.evaluate_datasets(plot=True)