model robustness testing

This commit is contained in:
Joseph Hopfmüller
2025-01-10 23:40:54 +01:00
parent 3af73343c1
commit f38d0ca3bb
13 changed files with 1558 additions and 334 deletions

View File

@@ -0,0 +1,671 @@
"""
generate_signal.py
This file is part of the repo "optical-regeneration"
https://git.suuppl.dev/seppl/optical-regeneration.git
Joseph Hopfmüller
Copyright 2024
Licensed under the EUPL
Full license text in LICENSE file
"""
import configparser
from datetime import datetime
import hashlib
from pathlib import Path
import time
import h5py
from matplotlib import pyplot as plt # noqa: F401
import numpy as np
from . import add_pypho # noqa: F401
import pypho
default_config = f"""
[glova]
sps = 128
nos = 16384
f0 = 193414489032258.06
symbolrate = 10e9
wisdom_dir = "{str((Path.home() / ".pypho"))}"
flags = "FFTW_PATIENT"
nthreads = 32
[fiber]
length = 10000
gamma = 1.14
alpha = 0.2
D = 17
S = 0.058
bireflength = 10
max_delta_beta = 0.14
; birefseed = 0xC0FFEE
[signal]
; seed = 0xC0FFEE
modulation = "pam"
mod_order = 4
mod_depth = 1
max_jitter = 0.02
; jitter_seed = 0xC0FFEE
laser_power = 0
edfa_power = 0
edfa_nf = 5
pulse_shape = "gauss"
fwhm = 0.33
osnr = "inf"
[data]
dir = "data"
npy_dir = "npys"
"""
def get_config(config_file=None):
"""
DANGER! The function uses eval() to parse the config file. Do not use this function with untrusted input.
"""
if config_file is None:
config_file = Path(__file__).parent / "signal_generation.ini"
config_file = Path(config_file)
if not config_file.exists():
with open(config_file, "w") as f:
f.write(default_config)
config = configparser.ConfigParser()
config.read(config_file)
conf = {}
for section in config.sections():
# print(f"[{section}]")
conf[section] = {}
for key in config[section]:
# print(f"{key} = {config[section][key]}")
try:
conf[section][key] = eval(config[section][key])
except NameError:
conf[section][key] = float(config[section][key])
# if isinstance(conf[section][key], str):
# conf[section][key] = config[section][key].strip('"')
return conf
class PDM_IM_IPM:
def __init__(
self,
glova,
mod_order=8,
seed=None,
):
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
"mod_order must be a cube of an integer greater than 1"
)
self.glova = glova
self.mod_order = mod_order
self.symbols_per_dim = int(np.cbrt(mod_order))
self.seed = seed
def generate_symbols(self, n):
rs = np.random.RandomState(self.seed)
symbols = rs.randint(0, self.mod_order, n)
return symbols
class pam_generator:
def __init__(
self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
) -> None:
self.glova = glova
self.pulse_shape = pulse_shape
self.modulation_depth = mod_depth
self.mod_order = mod_order
self.fwhm = fwhm
self.seed = seed
self.single_channel = single_channel
def __call__(self, E, symbols, max_jitter=0):
max_jitter = int(round(max_jitter * self.glova.sps))
if self.pulse_shape == "gauss":
wavelet = self.gauss(oversampling=6)
else:
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")
# prepare symbols
symbols_x = symbols[0] / (self.mod_order)
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
# create analog signal of diff of symbols
E_x = np.convolve(digital_x, wavelet)
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
# cut off the wavelet tails
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
# modulate the laser
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))
if not self.single_channel:
symbols_y = symbols[1] / (self.mod_order)
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
E_y = np.convolve(digital_y, wavelet)
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
# rotate the signal on the y-polarisation by 90°
# E[0]["E"][1] *= 1j
else:
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
return E
def generate_digital_signal(self, symbols, max_jitter=0):
rs = np.random.RandomState(self.seed)
signal = np.zeros(self.glova.nos * self.glova.sps)
for index in range(self.glova.nos):
jitter = max_jitter != 0 and rs.randint(-max_jitter, max_jitter)
signal_index = index * self.glova.sps + jitter
if signal_index < 0:
continue
if signal_index >= len(signal):
continue
signal[signal_index] = symbols[index]
return signal
def gauss(self, oversampling=1):
sample_points = np.linspace(
-oversampling * self.glova.sps,
oversampling * self.glova.sps,
oversampling * 2 * self.glova.sps,
endpoint=True,
)
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
return pulse
def initialize_fiber_and_data(config):
py_glova = pypho.setup(
nos=config["glova"]["nos"],
sps=config["glova"]["sps"],
f0=config["glova"]["f0"],
symbolrate=config["glova"]["symbolrate"],
wisdom_dir=config["glova"]["wisdom_dir"],
flags=config["glova"]["flags"],
nthreads=config["glova"]["nthreads"],
)
c_glova = pypho.cfiber.GlovaWrapper.from_setup(py_glova)
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"],
fwhm=config["signal"]["fwhm"],
seed=config["signal"]["jitter_seed"],
mod_order=config["signal"]["mod_order"],
)
symbols_x = symbolsrc(pattern="random")
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0
symbols_y[:3] = 0
# symbols_x += 1
cw = laser()
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
if osnr != float("inf"):
osnr_lin = 10 ** (osnr / 10)
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
noise_power = signal_power / osnr_lin
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
0, 1, source_signal[0]["E"].shape
)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
noise = noise * np.sqrt(noise_power / noise_power_is)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
source_signal[0]["E"] += noise
source_signal[0]["noise"] = noise_power_is
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
nf = py_edfa.NF
source_signal = py_edfa(E=source_signal, NF=0)
py_edfa.NF = nf
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
py_fiber = pypho.fiber(
glova=py_glova,
l=config["fiber"]["length"],
alpha=pypho.functions.dB_to_Neper(config["fiber"]["alpha"]) / 1000,
gamma=config["fiber"]["gamma"],
D=config["fiber"]["d"],
S=config["fiber"]["s"],
)
if config["fiber"].get("birefsteps", 0) > 0:
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l,
py_fiber.l / config["fiber"]["birefsteps"],
# maxDeltaD=config["fiber"]["d"]/5,
maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
seed=seed,
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y)
def save_data(data, config, **metadata):
data_dir = Path(config["data"]["dir"])
npy_dir = config["data"].get("npy_dir", "")
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
save_dir.mkdir(parents=True, exist_ok=True)
save_data = np.column_stack([
data.E_in[0],
data.E_in[1],
data.E_out[0],
data.E_out[1],
])
timestamp = datetime.now()
seed = config["signal"].get("seed", False)
jitter_seed = config["signal"].get("jitter_seed", False)
birefseed = config["fiber"].get("birefseed", False)
osnr = float(config["signal"].get("osnr", "inf"))
config_content = "\n".join((
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
"[glova]",
f"sps = {config['glova']['sps']}",
f"nos = {config['glova']['nos']}",
f"f0 = {config['glova']['f0']}",
f"symbolrate = {config['glova']['symbolrate']}",
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
f'flags = "{config["glova"]["flags"]}"',
f"nthreads = {config['glova']['nthreads']}",
"",
"[fiber]",
f"length = {config['fiber']['length']}",
f"gamma = {config['fiber']['gamma']}",
f"alpha = {config['fiber']['alpha']}",
f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
"",
"[signal]",
f"seed = {hex(seed)}" if seed else "; seed = not set",
"",
f'modulation = "{config["signal"]["modulation"]}"',
f"mod_order = {config['signal']['mod_order']}",
f"mod_depth = {config['signal']['mod_depth']}",
"",
f"max_jitter = {config['signal']['max_jitter']}",
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
"",
f"laser_power = {config['signal']['laser_power']}",
f"edfa_power = {config['signal']['edfa_power']}",
f"edfa_nf = {config['signal']['edfa_nf']}",
f"osnr = {osnr}",
"",
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
f"fwhm = {config['signal']['fwhm']}",
"",
"[data]",
f'dir = "{str(data_dir)}"',
f'npy_dir = "{npy_dir}"',
"file = ",
))
config_hash = hashlib.md5(config_content.encode()).hexdigest()
save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n'
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config["fiber"].get("birefsteps", 0),
config["fiber"].get("max_delta_beta", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
config_filename = data_dir / lookup_file
with open(config_filename, "w") as f:
f.write(config_content)
with h5py.File(save_dir / save_file, "w") as outfile:
outfile.create_dataset("data", data=save_data)
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
for key, value in metadata.items():
# if isinstance(value, dict):
# value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data)
print("Saved config to", config_filename)
print("Saved data to", save_dir / save_file)
return config_filename
def length_loop(config, lengths, save=True):
lengths = sorted(lengths)
for length in lengths:
print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]["E"]
mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
if save:
save_data(cdata, config)
in_out_eyes(cfiber, cdata)
def single_run_with_plot(config, save=True):
cfiber, cdata, config_filename = single_run(config, save)
in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename
def single_run(config, save=True):
cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)
# mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
# print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_in / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
cfiber()
# mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
# noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_out / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
E_tmp = [{"E": cdata.E_out, "noise": noise}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]["E"]
# noise = E_tmp[0]["noise"]
# mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
# print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
# estimate osnr
# noise_power = np.mean(noise)
# osnr_lin = mean_power_amp / noise_power - 1
# osnr = 10 * np.log10(osnr_lin)
# print(f"Estimated OSNR: {osnr:.3f} dB")
config_filename = None
symbols = np.array(symbols)
if save:
config_filename = save_data(cdata, config, **{"symbols": symbols})
return cfiber,cdata,config_filename
def in_out_eyes(cfiber, cdata, show_pols=False):
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
eye_head = min(cfiber.glova.nos, 2000)
symbolrate_scale = 1e12
amplitude_scale = 1e3
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[0]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][0],
show=False,
color="C0",
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[0].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][0],
color="C2",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[0].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][0],
color="C3",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[0]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C1",
show=False,
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[0].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C4",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[0].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C5",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[1]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][0],
color="C0",
show=False,
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[1].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][0],
color="C2",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[1].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][0],
color="C3",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[1]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][1],
color="C1",
show=False,
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[1].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C4",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[1].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C5",
show=False,
)
title_map = [
["Input x", "Output x"],
["Input y", "Output y"],
]
title_map = np.array(title_map)
for ax, title in zip(axs.flatten(), title_map.flatten()):
ax.grid(True)
ax.set_xlabel("Time [ps]")
ax.set_ylabel("Power [mW]")
ax.set_title(title)
fig.tight_layout()
plt.show()
def plot_eye_diagram(
signal: np.ndarray,
eye_width,
offset=0,
*,
head=None,
samplerate=1,
normalize=True,
ax=None,
color="C0",
show=True,
):
ax = ax or plt.gca()
if head is not None:
signal = signal[: head * eye_width]
if normalize:
signal = signal / np.max(signal)
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
for slice in slices:
ax.plot(plt_ax, slice, color=color, alpha=0.1)
ax.grid()
if show:
plt.show()
if __name__ == "__main__":
add_pypho.show_log()
config = get_config()
# ranges = (1000,10000)
# scales = tuple(range(1, 10))
# scales = (1,)
# lengths = [range_ * scale for range_ in ranges for scale in scales]
# lengths.append(10*max(ranges))
# lengths = [*lengths, *lengths]
lengths = (
# 8000, 9000,
10000,
20000,
30000,
40000,
50000,
60000,
70000,
80000,
90000,
95000,
100000,
105000,
110000,
115000,
120000,
)
# lengths = (10000,100000)
# length_loop(config, lengths, save=True)
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
single_run_with_plot(config, save=False)