model robustness testing

Joseph Hopfmüller
2025-01-10 23:40:54 +01:00
parent 3af73343c1
commit f38d0ca3bb
13 changed files with 1558 additions and 334 deletions


@@ -0,0 +1,723 @@
"""
tests a given model for tolerance against variations in
- fiber length
- baudrate
- OSNR
CD, PMD, baudrate need different datasets, osnr is modeled as awgn added to the data before feeding into the model
"""
from datetime import datetime
from typing import Literal
from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path
import h5py
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models
from signal_gen.generate_signal import single_run, get_config
import json
class NestedParameterIterator:
def __init__(self, parameters):
"""
parameters: dict with key <param_name> and value <dict with keys "config" and "range">
"""
# self.parameters = parameters
self.names = []
self.ranges = []
self.configs = []
for k, v in parameters.items():
self.names.append(k)
self.ranges.append(v["range"])
self.configs.append(v["config"])
self.n_parameters = len(self.ranges)
self.idx = 0
self.range_idx = [0] * self.n_parameters
self.range_len = [len(r) for r in self.ranges]
self.length = int(np.prod(self.range_len))
self.out = []
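        # Enumerate the Cartesian product of all parameter ranges by treating
        # range_idx as a mixed-radix counter (the last parameter varies fastest).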
for i in range(self.length):
self.out.append([])
for j in range(self.n_parameters):
element = {self.names[j]: {"value": self.ranges[j][self.range_idx[j]], "config": self.configs[j]}}
self.out[i].append(element)
self.range_idx[-1] += 1
            # propagate carries from the last (fastest-varying) parameter towards the first
            for j in range(self.n_parameters - 1, 0, -1):
                if self.range_idx[j] == self.range_len[j]:
                    self.range_idx[j] = 0
                    self.range_idx[j - 1] += 1
...
def __next__(self):
if self.idx == self.length:
raise StopIteration
self.idx += 1
return self.out[self.idx - 1]
def __iter__(self):
return self
class model_runner:
def __init__(
self,
# length_range: tuple[int | float] = (50e3, 50e3),
# length_steps: int = 1,
# length_log: bool = False,
# baudrate_range: tuple[int | float] = (10e9, 10e9),
# baudrate_steps: int = 1,
# baudrate_log: bool = False,
# osnr_range: tuple[int | float] = (16, 16),
# osnr_steps: int = 1,
# osnr_log: bool = False,
# dataset_dir: str = "data",
# dataset_datetime_glob: str = "*",
results_dir: str = "tolerance_results/datasets",
# model_dir: str = ".models",
config: str = "signal_generation.ini",
config_dir: str = None,
debug: bool = False,
):
"""
length_range: lower and upper limit of length, in meters
length_step: step size of length, in meters
baudrate_range: lower and upper limit of baudrate, in Bd
baudrate_step: step size of baudrate, in Bd
osnr_range: lower and upper limit of osnr, in dB
osnr_step: step size of osnr, in dB
dataset_dir: directory containing datasets
dataset_datetime_glob: datetime glob pattern for dataset files
results_dir: directory to save results
model_dir: directory containing models
"""
self.debug = debug
self.parameters = {}
self.iter = None
# self.update_length_range(length_range, length_steps, length_log)
# self.update_baudrate_range(baudrate_range, baudrate_steps, baudrate_log)
# self.update_osnr_range(osnr_range, osnr_steps, osnr_log)
# self.data_dir = Path(dataset_dir)
# self.data_datetime_glob = dataset_datetime_glob
self.results_dir = Path(results_dir)
# self.model_dir = Path(model_dir)
        config_dir = Path(config_dir) if config_dir is not None else Path(__file__).parent
        self.config = config_dir / config
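        # allow-list the custom classes stored in the checkpoint so that
        # torch.load(..., weights_only=True) can deserialize them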
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
self.load_model()
self.datasets = []
# def register_parameter(self, name, config):
# self.parameters.append({"name": name, "config": config})
def load_results_from_file(self, path):
data, meta = self.load_from_file(path)
self.results = [d.decode() for d in data]
self.parameters = meta["parameters"]
...
def load_datasets_from_file(self, path):
data, meta = self.load_from_file(path)
self.datasets = [d.decode() for d in data]
self.parameters = meta["parameters"]
...
    def update_parameter_range(self, name, config, value_range, steps, log):
        self.parameters[name] = {"config": config, "range": self.update_range(*value_range, steps, log)}
def generate_iterations(self):
if len(self.parameters) == 0:
raise ValueError("No parameters registered")
self.iter = NestedParameterIterator(self.parameters)
def generate_datasets(self):
# get base config
config = get_config(self.config)
if self.iter is None:
self.generate_iterations()
for params in self.iter:
current_settings = []
# params is a list of dictionaries with keys "name", containing a dict with keys "value", "config"
for param in params:
for name, settings in param.items():
current_settings.append({name: settings["value"]})
self.nested_set(config, settings["config"], settings["value"])
settings_strs = []
for setting in current_settings:
name = list(setting)[0]
settings_strs.append(f"{name}: {float(setting[name]):.2e}")
settings_str = ", ".join(settings_strs)
print(f"Generating dataset for [{settings_str}]")
# TODO look for existing datasets
_, _, path = single_run(config)
self.datasets.append(str(path))
datasets_list_path = self.build_path("datasets_list", parent_dir=self.results_dir, timestamp="back")
metadata = {"parameters": self.parameters}
data = np.array(self.datasets, dtype="S")
self.save_to_file(datasets_list_path, data, **metadata)
@staticmethod
def nested_set(dic, keys, value):
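        # walk down the nested dict, creating intermediate levels as needed, and set the final key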
for key in keys[:-1]:
dic = dic.setdefault(key, {})
dic[keys[-1]] = value
## Dataset and model loading
# def find_datasets(self, data_dir=None, data_datetime_glob=None):
# # date-time-sps-nos-length-gamma-alpha-D-S-PAM4-birefsteps-deltabeta-symbolrate.ini
# data_dir = data_dir or self.data_dir
# data_datetime_glob = data_datetime_glob or self.data_datetime_glob
# self.datasets = {}
# data_dir = Path(data_dir)
# for length in self.lengths:
# for baudrate in self.baudrates:
# # dataset_glob = self.data_datetime_glob + f"*-*-{int(length)}-*-*-*-*-PAM4-*-*-{int(baudrate/1e9)}.ini"
# dataset_glob = data_datetime_glob + f"-*-*-{int(length)}-*-*-*-*-PAM4-*-*.ini"
# datasets = [f for f in data_dir.glob(dataset_glob)]
# if len(datasets) == 0:
# continue
# self.datasets[length] = {}
# if len(datasets) > 1:
# print(
# f"multiple datasets found for [{length / 1000:.1f} km, {int(baudrate / 1e9)} GBd]. Using the newest dataset."
# )
# # get newest file from creation date
# datasets.sort(key=lambda x: x.stat().st_ctime)
# self.datasets[length][baudrate] = str(datasets[-1])
def load_dataset(self, dataset_path):
if self.checkpoint_dict is None:
raise ValueError("Model must be loaded before dataset")
        if self.dataset_path == dataset_path:
            # dataset is already loaded; return the cached original symbols
            return self.dataloader.dataset.orig_symbols
        self.dataset_path = dataset_path
symbols = self.checkpoint_dict["settings"]["data_settings"].symbols
data_size = self.checkpoint_dict["settings"]["data_settings"].output_size
dtype = getattr(torch, self.checkpoint_dict["settings"]["data_settings"].dtype)
drop_first = self.checkpoint_dict["settings"]["data_settings"].drop_first
randomise_polarisations = self.checkpoint_dict["settings"]["data_settings"].randomise_polarisations
polarisations = self.checkpoint_dict["settings"]["data_settings"].polarisations
num_symbols = None
if self.debug:
num_symbols = 1000
config_path = Path(dataset_path)
dataset = util.datasets.FiberRegenerationDataset(
file_path=config_path,
symbols=symbols,
output_dim=data_size,
drop_first=drop_first,
dtype=dtype,
real=not dtype.is_complex,
randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
num_symbols=num_symbols,
# device="cuda" if torch.cuda.is_available() else "cpu",
)
self.dataloader = torch.utils.data.DataLoader(
dataset, batch_size=2**14, pin_memory=True, num_workers=24, prefetch_factor=8, shuffle=False
)
return self.dataloader.dataset.orig_symbols
# run model
# return results as array: [fiber_in, fiber_out, fiber_out_noisy, regen_out]
def load_model(self, model_path: str | None = None):
        if model_path is None:
            self.model = None
            self.model_path = None
            self.checkpoint_dict = None
            self.dataset_path = None
            return
        path = Path(model_path)
        if path == self.model_path:
            return
        self.model_path = path
self.dataset_path = None # reset dataset path, as the shape depends on the model
self.checkpoint_dict = torch.load(path, weights_only=True)
dims = self.checkpoint_dict["model_kwargs"].pop("dims")
self.model = models.regenerator(*dims, **self.checkpoint_dict["model_kwargs"])
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"])
## Model evaluation
def run_model_evaluation(self, model_path: str, datasets: str | None = None):
self.load_model(model_path)
# iterate over datasets and osnr values:
# load dataset, add noise, run model, return results
# save results to file
self.results = []
if datasets is not None:
self.load_datasets_from_file(datasets)
n_datasets = len(self.datasets)
for i, dataset_path in enumerate(self.datasets):
conf = get_config(dataset_path)
mpath = Path(model_path)
model_base = mpath.stem
print(f"({1+i}/{n_datasets}) Running model {model_base} with dataset {dataset_path.split('/')[-1]}")
results_path = self.build_path(
dataset_path.split("/")[-1], parent_dir=Path(self.results_dir) / model_base
)
orig_symbols = self.load_dataset(dataset_path)
data, loss = self.run_model()
metadata = {
"model_path": model_path,
"dataset_path": dataset_path,
"loss": loss,
"sps": conf["glova"]["sps"],
"orig_symbols": orig_symbols
# "config": conf,
# "checkpoint_dict": self.checkpoint_dict,
# "nos": self.dataloader.dataset.num_symbols,
}
self.save_to_file(results_path, data, **metadata)
self.results.append(str(results_path))
results_list_path = self.build_path("results_list", parent_dir=self.results_dir, timestamp="back")
metadata = {"parameters": self.parameters}
data = np.array(self.results, dtype="S")
self.save_to_file(results_list_path, data, **metadata)
def run_model(self):
loss = 0
datas = []
self.model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = self.model.to(device)
with torch.no_grad():
for batch in self.dataloader:
x = batch["x_stacked"]
y = batch["y_stacked"]
fiber_in = batch["plot_target"]
# fiber_out = batch["plot_clean"]
fiber_out = batch["plot_data"]
timestamp = batch["timestamp"]
angle = batch["mean_angle"]
                x = x.to(device)
                angle = angle.to(device)
regen = model(x, -angle)
regen = regen.to("cpu")
loss += util.complexNN.complex_mse_loss(regen, y, power=True).item()
# shape: [batch_size, 4]
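                # reshape the regenerated output to (batch, -1, 2) and keep only the centre (x, y) pair of each window for plotting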
plot_regen = regen[:, :2]
plot_regen = plot_regen.view(plot_regen.shape[0], -1, 2)
plot_regen = plot_regen[:, plot_regen.shape[1] // 2, :]
data_out = torch.cat(
(
fiber_in,
fiber_out,
# fiber_out_noisy,
plot_regen,
timestamp.view(-1, 1),
),
dim=1,
)
datas.append(data_out)
data_out = torch.cat(datas, dim=0).numpy()
return data_out, loss
## File I/O
@staticmethod
    def save_to_file(path: str | Path, data: np.ndarray, **metadata):
        path = Path(path)
        # create the parent directory if it doesn't exist
        path.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(path, "w") as outfile:
outfile.create_dataset("data", data=data)
for key, value in metadata.items():
if isinstance(value, dict):
value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
@staticmethod
def convert_arrays(dict_in):
"""
convert ndarrays in (nested) dict to lists
"""
dict_out = {}
for key, value in dict_in.items():
if isinstance(value, dict):
dict_out[key] = model_runner.convert_arrays(value)
elif isinstance(value, np.ndarray):
dict_out[key] = value.tolist()
else:
dict_out[key] = value
return dict_out
@staticmethod
def load_from_file(path: str):
with h5py.File(path, "r") as infile:
data = infile["data"][:]
metadata = {}
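            # attributes that were stored as JSON strings (nested dicts) are decoded back into Python objects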
for key in infile.attrs.keys():
if isinstance(infile.attrs[key], (str, bytes, bytearray)):
try:
metadata[key] = json.loads(infile.attrs[key])
except json.JSONDecodeError:
metadata[key] = infile.attrs[key]
else:
metadata[key] = infile.attrs[key]
return data, metadata
## Utility functions
@staticmethod
def logrange(start, stop, num, endpoint=False):
lower, upper = np.log10((start, stop))
return np.logspace(lower, upper, num=num, endpoint=endpoint, base=10)
@staticmethod
def build_path(
*elements, parent_dir: str | Path | None = None, filetype="h5", timestamp: Literal["no", "front", "back"] = "no"
):
suffix = f".{filetype}" if not filetype.startswith(".") else filetype
if timestamp != "no":
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
if timestamp == "front":
elements = (ts, *elements)
elif timestamp == "back":
elements = (*elements, ts)
path = "_".join(elements)
path += suffix
if parent_dir is not None:
path = Path(parent_dir) / path
return path
@staticmethod
    def update_range(start, stop, n_steps, log):
        if log:
            values = model_runner.logrange(start, stop, n_steps, endpoint=True)
        else:
            values = np.linspace(start, stop, n_steps, endpoint=True)
        return values
class model_evaluation_result:
def __init__(
self,
*,
length=None,
baudrate=None,
osnr=None,
model_path=None,
dataset_path=None,
loss=None,
sps=None,
**kwargs,
):
self.length = length
self.baudrate = baudrate
self.osnr = osnr
self.model_path = model_path
self.dataset_path = dataset_path
self.loss = loss
self.sps = sps
self.sers = None
self.bers = None
self.eye_stats = None
class evaluator:
def __init__(self, datasets: list[str]):
"""
datasets: iterable of dataset paths
data_dir: directory containing datasets
"""
self.datasets = datasets
self.results = []
def evaluate_datasets(self, plot=False):
## iterate over datasets
# load dataset
for dataset in self.datasets:
model, dataset_name = dataset.split("/")[-2:]
print(f"\nEvaluating model {model} with dataset {dataset_name}")
data, metadata = model_runner.load_from_file(dataset)
result = model_evaluation_result(**metadata)
data = self.prepare_data(data, sps=metadata["sps"])
try:
sym_x, sym_y = metadata["orig_symbols"]
except (TypeError, KeyError, ValueError):
sym_x, sym_y = None, None
self.evaluate_eye(data, result, title=dataset.split("/")[-1], plot=False)
self.evaluate_ser_ber(data, result, sym_x, sym_y)
print("BER:")
self.print_dict(result.bers["regen"])
print()
print("SER:")
self.print_dict(result.sers["regen"])
print()
self.results.append(result)
if plot:
plt.show()
def evaluate_eye(self, data, result, title=None, plot=False):
eye = util.eye_diagram.eye_diagram(
data,
channel_names=[
"fiber_in_x",
"fiber_in_y",
# "fiber_out_x",
# "fiber_out_y",
"fiber_out_x",
"fiber_out_y",
"regen_x",
"regen_y",
],
)
eye.analyse()
eye.plot(title=title or "Eye diagram", show=plot)
result.eye_stats = eye.eye_stats
return eye.eye_stats
...
def evaluate_ser_ber(self, data, result, sym_x=None, sym_y=None):
if result.eye_stats is None:
self.evaluate_eye(data, result)
symbols = []
sers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
bers = {"fiber_out": {"x": None, "y": None}, "regen": {"x": None, "y": None}}
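        # channel order (from prepare_data / evaluate_eye): fiber_in_x, fiber_in_y, fiber_out_x, fiber_out_y, regen_x, regen_y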
for channel_data, stats in zip(data, result.eye_stats):
timestamps = channel_data[0]
dat = channel_data[1]
channel_name = stats["channel_name"]
if stats["success"]:
thresholds = stats["thresholds"]
time_midpoint = stats["time_midpoint"]
else:
if channel_name.endswith("x"):
thresholds = result.eye_stats[0]["thresholds"]
time_midpoint = result.eye_stats[0]["time_midpoint"]
elif channel_name.endswith("y"):
thresholds = result.eye_stats[1]["thresholds"]
time_midpoint = result.eye_stats[1]["time_midpoint"]
else:
levels = np.linspace(np.min(dat), np.max(dat), 4)
thresholds = util.eye_diagram.eye_diagram.calculate_thresholds(levels)
time_midpoint = 1.0
# time_offset = time_midpoint - 0.5
# # time_offset = 0
# index_offset = np.argmin(np.abs((timestamps - time_offset) % 1.0))
nos = len(timestamps) // result.sps
# idx = np.arange(index_offset, len(timestamps), result.sps).astype(int)
# if time_offset < 0:
# idx = np.insert(idx, 0, 0)
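            # sample one value per symbol, starting at the first sample; the eye-midpoint time offset is not applied here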
            idx = list(range(0, len(timestamps), result.sps))
idx = idx[:nos]
data_sampled = dat[idx]
detected_symbols = self.detect_symbols(data_sampled, thresholds)
symbols.append({"channel_name": channel_name, "symbols": detected_symbols})
        symbols_x_gt = sym_x if sym_x is not None else symbols[0]["symbols"]
        symbols_y_gt = sym_y if sym_y is not None else symbols[1]["symbols"]
symbols_x_fiber_out = symbols[2]["symbols"]
symbols_y_fiber_out = symbols[3]["symbols"]
symbols_x_regen = symbols[4]["symbols"]
symbols_y_regen = symbols[5]["symbols"]
sers["fiber_out"]["x"], bers["fiber_out"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_fiber_out)
sers["fiber_out"]["y"], bers["fiber_out"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_fiber_out)
sers["regen"]["x"], bers["regen"]["x"] = self.calculate_ser_ber(symbols_x_gt, symbols_x_regen)
sers["regen"]["y"], bers["regen"]["y"] = self.calculate_ser_ber(symbols_y_gt, symbols_y_regen)
result.sers = sers
result.bers = bers
@staticmethod
def calculate_ser_ber(symbols_gt, symbols):
# levels = 4
# symbol difference -> bit error count
# |rx - tx| = 0 -> 0
# |rx - tx| = 1 -> 1
# |rx - tx| = 2 -> 2
# |rx - tx| = 3 -> 1
# assuming gray coding -> 0: 00, 1: 01, 2: 11, 3: 10
bec_map = {0: 0, 1: 1, 2: 2, 3: 1, np.nan: 2}
ser = {}
ber = {}
ser["n_symbols"] = len(symbols_gt)
ser["n_errors"] = np.sum(symbols != symbols_gt)
ser["total"] = float(ser["n_errors"] / ser["n_symbols"])
bec = np.vectorize(bec_map.get)(np.abs(symbols - symbols_gt))
bit_errors = np.sum(bec)
ber["n_bits"] = len(symbols_gt) * 2
ber["n_errors"] = bit_errors
ber["total"] = float(ber["n_errors"] / ber["n_bits"])
return ser, ber
@staticmethod
def print_dict(d: dict, indent=2, logarithmic=False, level=0):
for key, value in d.items():
if isinstance(value, dict):
print(f"{' ' * indent * level}{key}:")
evaluator.print_dict(value, indent=indent, logarithmic=logarithmic, level=level + 1)
else:
if isinstance(value, float):
if logarithmic:
if value == 0:
value = -np.inf
else:
value = np.log10(value)
print(f"{' ' * indent * level}{key}: {value:.2e}\t", end="")
else:
print(f"{' ' * indent * level}{key}: {value}\t", end="")
print()
@staticmethod
def detect_symbols(samples, thresholds=None):
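        # default thresholds are the midpoints between the four normalized PAM4 levels (0, 1/3, 2/3, 1);
        # np.digitize then maps each sample to a symbol index 0..3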
thresholds = (1 / 6, 3 / 6, 5 / 6) if thresholds is None else thresholds
thresholds = (-np.inf, *thresholds, np.inf)
symbols = np.digitize(samples, thresholds) - 1
return symbols
@staticmethod
def prepare_data(data, sps=None):
data = data.transpose(1, 0)
timestamps = data[-1].real
data = data[:-1]
if sps is not None:
timestamps /= sps
# data = np.stack(
# (
# *data[0:2], # fiber_in_x, fiber_in_y
# # *data_[2:4], # fiber_out_x, fiber_out_y
# *data[4:6], # fiber_out_noisy_x, fiber_out_noisy_y
# *data[6:8], # regen_out_x, regen_out_y
# ),
# axis=0,
# )
data_eye = []
for channel_values in data:
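            # convert complex field amplitudes to optical power |E|^2 for the eye analysis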
channel_values = np.square(np.abs(channel_values))
data_eye.append(np.stack((timestamps, channel_values), axis=0))
data_eye = np.stack(data_eye, axis=0)
return data_eye
def generate_data(parameters, runner=None):
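    """Register the given parameter sweeps on a model_runner, build all combinations and generate the datasets."""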
runner = runner or model_runner()
for param in parameters:
runner.update_parameter_range(*param)
runner.generate_iterations()
print(f"{runner.iter.length} parameter combinations")
runner.generate_datasets()
return runner
if __name__ == "__main__":
model_path = ".models/best_20250110_191149.tar" # D 17, OSNR 100, delta_beta 0.14, baud 10e9
parameters = (
# name, config keys, (min, max), n_steps, log
# ("D", ("fiber", "d"), (28,30), 3, False),
# ("S", ("fiber", "s"), (0, 0.058), 2, False),
("OSNR", ("signal", "osnr"), (20, 40), 5, False),
# ("PMD", ("fiber", "max_delta_beta"), (0, 0.28), 3, False),
# ("Baud", ("glova", "symbolrate"), (10e9, 100e9), 3, True),
)
datasets = None
results = None
# datasets = "tolerance_results/datasets/datasets_list_20250110_223337.h5"
results = "tolerance_results/datasets/results_list_20250110_232639.h5"
runner = model_runner()
# generate_data(parameters, runner)
if results is None:
if datasets is None:
generate_data(parameters, runner)
else:
runner.load_datasets_from_file(datasets)
print(f"{len(runner.datasets)} loaded")
runner.run_model_evaluation(model_path)
else:
runner.load_results_from_file(results)
# print(runner.parameters)
# print(runner.results)
    ev = evaluator(runner.results)
    ev.evaluate_datasets(plot=True)