wip

.gitignore (vendored, 3 changed lines)
@@ -163,4 +163,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-tolerance_results/datasets/*
+tolerance_results/*
+data/*
(new Git LFS pointer file; file name not shown in this view)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
+size 649
data/npys/6789fdea2609799ef2e975907625b79a.h5 (new file, 3 lines)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
+size 134481920
notes/tolerance_testing.md (new file, 59 lines)
@@ -0,0 +1,59 @@
+# Baseline Models
+
+## a) D+S, pol_error 0, ortho_error 0, DGD 0
+
+dataset
+
+```raw
+data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
+```
+
+model
+
+```raw
+.models/best_20250118_225918.tar
+```
+
+## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
+
+dataset
+
+```raw
+data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
+```
+
+model
+
+```raw
+.models/best_20250116_214816.tar
+```
+
+## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
+
+dataset
+
+```raw
+data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
+```
+
+model
+
+```raw
+.models/best_20250117_122319.tar
+```
+
+## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
+
+birefringence angle pi/2 (worst case)
+
+dataset
+
+```raw
+data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
+```
+
+model
+
+```raw
+.models/best_20250117_144001.tar
+```
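A minimal sketch of how one of these baseline checkpoints can be restored, mirroring the load path used by `model_plot` further down in this commit (it assumes this repo's `hypertraining.models` module and the checkpoint layout the trainer saves — settings, model kwargs and state dict bundled in one dict):

```python
import torch
from hypertraining import models

# The listed .tar files are trainer checkpoints; weights_only=True requires
# the custom settings classes to be registered first (the trainer does this
# via torch.serialization.add_safe_globals).
ckpt = torch.load(".models/best_20250118_225918.tar", weights_only=True)

dims = ckpt["model_kwargs"].pop("dims")  # per-layer sizes of the ONN
model = models.regenerator(*dims, **ckpt["model_kwargs"])
model.load_state_dict(ckpt["model_state_dict"], strict=False)
model.eval()
```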
pypho (submodule)
Submodule pypho updated: dd015f4852...e44fc477fe
(models file; file name not shown in this view)
@@ -164,10 +164,14 @@ class regenerator(Module):
         module = act_function(size=dims[-1], **act_func_kwargs)
         self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
+
+        module = Scale(size=dims[-1])
+        self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
+
         if self.rotation:
             module = rotate()
             self.add_module("rotate", module)


         # module = Scale(size=dims[-1])
         # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
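`Scale` itself is not defined anywhere in this diff; a hypothetical stand-in consistent with how it is constructed above (`Scale(size=dims[-1])`, appended after the final activation) would be one learnable complex gain per output node:

```python
import torch

class Scale(torch.nn.Module):
    """Hypothetical stand-in for the out_scale module added above:
    one learnable complex gain per output node. The real class may differ."""
    def __init__(self, size):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size, dtype=torch.complex64))

    def forward(self, x):
        return x * self.weight
```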
(settings file; file name not shown in this view)
@@ -18,9 +18,11 @@ class DataSettings:
     shuffle: bool = True
     in_out_delay: float = 0
     xy_delay: tuple | float | int = 0
-    drop_first: int = 1000
+    drop_first: int = 64
+    drop_last: int = 64
     train_split: float = 0.8
     polarisations: tuple | list = (0,)
+    # cross_pol_interference: float = 0
     randomise_polarisations: bool = False
     osnr: float | int = None
     seed: int = None
@@ -93,6 +95,12 @@ class ModelSettings:
     """


+def _early_stop_default_kwargs():
+    return {
+        "threshold": 1e-05,
+        "plateau": 25,
+    }
+
+
 @dataclass
 class OptimizerSettings:
     optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
@@ -101,6 +109,9 @@ class OptimizerSettings:
     scheduler: str | None = None
     scheduler_kwargs: dict | None = None

+    early_stopping: bool = False
+    early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
+
     """
     change to:
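The new `early_stop_kwargs` field goes through `field(default_factory=...)` because dataclasses reject mutable defaults; a self-contained sketch of the pattern, with the field names and defaults taken from the hunk above:

```python
from dataclasses import dataclass, field

def _early_stop_default_kwargs():
    return {"threshold": 1e-05, "plateau": 25}

@dataclass
class OptimizerSettings:
    early_stopping: bool = False
    # A dict default must go through default_factory: dataclasses raise on
    # mutable defaults, and a shared dict would leak state between instances.
    early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)

settings = OptimizerSettings()
assert settings.early_stop_kwargs["plateau"] == 25
```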
(trainer file; file name not shown in this view)
@@ -4,6 +4,7 @@ from pathlib import Path
 import random
 import matplotlib
 from matplotlib.colors import LinearSegmentedColormap
+from mpl_toolkits.axes_grid1 import make_axes_locatable
 import torch.nn.utils.parametrize

 try:
@@ -46,13 +47,72 @@ from .settings import (
     PytorchSettings,
 )

+from cmcrameri import cm
+# from matplotlib import colors as mcolors
+# alpha_map = mcolors.LinearSegmentedColormap(
+#     'alphamap',
+#     {
+#         'red': [(0, 0, 0), (1, 0, 0)],
+#         'green': [(0, 0, 0), (1, 0, 0)],
+#         'blue': [(0, 0, 0), (1, 0, 0)],
+#         'alpha': [
+#             (0, 1, 1),
+#             # (0.2, 0.2, 0.1),
+#             (1, 0, 0)
+#         ]
+#     }
+# )
+# alpha_map.set_bad(color="#AAAAAA")
+
+
+def pad_to_size(array, size):
+    if not hasattr(size, "__len__"):
+        size = (size, size)
+
+    left = (
+        (size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
+    )
+    right = (
+        (size[0] - array.shape[0]) // 2 if size[0] is not None else 0
+    )
+    top = (
+        (size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
+    )
+    bottom = (
+        (size[1] - array.shape[1]) // 2 if size[1] is not None else 0
+    )
+
+    array: np.ndarray = array
+    if array.ndim == 2:
+        return np.pad(
+            array,
+            (
+                (left, right),
+                (top, bottom),
+            ),
+            constant_values=(np.nan, np.nan),
+        )
+    elif array.ndim == 3:
+        return np.pad(
+            array,
+            (
+                (left, right),
+                (top, bottom),
+                (0, 0),
+            ),
+            constant_values=(np.nan, np.nan),
+        )
+
+
 def traverse_dict_update(target, source):
     for k, v in source.items():
         if isinstance(v, dict):
-            if k not in target:
-                target[k] = {}
-            traverse_dict_update(target[k], v)
+            try:
+                if k not in target:
+                    target[k] = {}
+                traverse_dict_update(target[k], v)
+            except TypeError:
+                if k not in target.__dict__:
+                    setattr(target, k, {})
+                traverse_dict_update(target.__dict__[k], v)
         else:
             try:
                 target[k] = v
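`pad_to_size` centres an array inside a NaN frame of the requested size, leaving a `None` dimension untouched; the NaNs are later masked so the padding renders grey in the weight plots. A quick sketch of its behaviour (assuming the definition from the hunk above is in scope):

```python
import numpy as np

a = np.ones((2, 3))
b = pad_to_size(a, (4, None))   # grow rows to 4, leave columns alone

print(b.shape)                  # (4, 3): one NaN row added above and below
print(np.isnan(b[0]).all())     # True: the frame is NaN, not zero
```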
@@ -261,6 +321,7 @@ class PolarizationTrainer:
             target_delay=in_out_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
+            drop_last=self.data_settings.drop_last,
             dtype=dtype,
             real=not dtype.is_complex,
             num_symbols=num_symbols,
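Conceptually, `drop_first`/`drop_last` trim symbols at both ends of each simulated trace before it is split into training windows; a toy illustration under that assumption (shapes hypothetical):

```python
import numpy as np

symbols = np.arange(16384)          # one simulated trace, as symbol indices
drop_first, drop_last = 64, 64      # defaults from the DataSettings hunk
usable = symbols[drop_first : len(symbols) - drop_last]
print(len(usable))                  # 16256 symbols survive the trimming
```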
@@ -602,6 +663,7 @@ class RegenerationTrainer:
         console=None,
         checkpoint_path=None,
         settings_override=None,
+        new_model=False,
         reset_epoch=False,
     ):
         self.resume = checkpoint_path is not None
@@ -615,12 +677,23 @@ class RegenerationTrainer:
             models.regenerator,
             torch.nn.utils.parametrizations.orthogonal,
         ])
+        # self.new_model = True
+        self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
         if self.resume:
             print(f"loading checkpoint from {checkpoint_path}")
             self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
             if settings_override is not None:
                 traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
-            if reset_epoch:
+            if not new_model:
+                # self.new_model = False
+                checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
+                if checkpoint_file.startswith("best"):
+                    self.model_name = "_".join(checkpoint_file.split("_")[1:])
+                else:
+                    self.model_name = "_".join(checkpoint_file.split("_")[:-1])
+
+            if new_model or reset_epoch:
                 self.checkpoint_dict["epoch"] = -1

         self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
@@ -654,7 +727,7 @@ class RegenerationTrainer:
             self.writer = None

     def setup_tb_writer(self, append=None):
-        log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
+        log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
         if append is not None:
             log_dir += "_" + str(append)
@@ -697,7 +770,7 @@ class RegenerationTrainer:

         output_dim = self.model_settings.output_dim

-        # if self.data_settings.polarisations:
+        if self.data_settings.polarisations:
             output_dim *= 2

         dtype = getattr(torch, self.data_settings.dtype)
@@ -755,11 +828,13 @@ class RegenerationTrainer:
         randomise_polarisations = self.data_settings.randomise_polarisations
         polarisations = self.data_settings.polarisations
         osnr = self.data_settings.osnr
+        # cross_pol_interference = self.data_settings.cross_pol_interference
         if override is not None:
             num_symbols = override.get("num_symbols", None)
             config_path = override.get("config_path", config_path)
             polarisations = override.get("polarisations", polarisations)
             randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
+            # cross_pol_interference = override.get("angle_var", 0)
         # get dataset
         dataset = FiberRegenerationDataset(
             file_path=config_path,
@@ -768,11 +843,13 @@ class RegenerationTrainer:
             target_delay=in_out_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
+            drop_last=self.data_settings.drop_last,
             dtype=dtype,
             real=not dtype.is_complex,
             num_symbols=num_symbols,
             randomise_polarisations=randomise_polarisations,
             polarisations=polarisations,
+            # cross_pol_interference=cross_pol_interference,
             osnr = osnr,
         )
@@ -842,8 +919,10 @@ class RegenerationTrainer:
         running_loss = 0.0
         self.model.train()
         loader_len = len(train_loader)
-        x_key = "x_stacked"# if self.data_settings.polarisations else "x"
-        y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
+        # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        x_key = "x"
+        y_key = "y"
         for batch_idx, batch in enumerate(train_loader):
             x = batch[x_key]
             y = batch[y_key]
@@ -855,7 +934,10 @@ class RegenerationTrainer:
                 angle.to(self.pytorch_settings.device),
             )
             y_pred = self.model(x, -angle)
+            # loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
             loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
+
+
             loss_value = loss.item()
             loss.backward()
             optimizer.step()
@@ -898,8 +980,10 @@ class RegenerationTrainer:

         self.model.eval()
         running_error = 0
-        x_key = "x_stacked"# if self.data_settings.polarisations else "x"
-        y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
+        # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        x_key = "x"
+        y_key = "y"
         with torch.no_grad():
             for _, batch in enumerate(valid_loader):
                 x = batch[x_key]
@@ -911,7 +995,9 @@ class RegenerationTrainer:
                     angle.to(self.pytorch_settings.device),
                 )
                 y_pred = self.model(x, -angle)
+                # error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
                 error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
+
                 error_value = error.item()
                 running_error += error_value
@@ -928,7 +1014,7 @@ class RegenerationTrainer:
         if (epoch + 1) % 10 == 0 or epoch < 10:
             # plotting is slow, so only do it every 10 epochs
             title_append, subtitle = self.build_title(epoch + 1)
-            head_fig, eye_fig, powers_fig = self.plot_model_response(
+            head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
                 model=self.model,
                 title_append=title_append,
                 subtitle=subtitle,
@@ -944,6 +1030,11 @@ class RegenerationTrainer:
                 eye_fig,
                 epoch + 1,
             )
+            self.writer.add_figure(
+                "weights",
+                weight_fig,
+                epoch + 1,
+            )
             self.writer.add_figure(
                 "powers",
@@ -967,9 +1058,10 @@ class RegenerationTrainer:
         regen = []
         timestamps = []
         angles = []
-        x_key = "x_stacked"# if self.data_settings.polarisations else "x"
-        y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
+        # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        x_key = "x"
+        y_key = "y"
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
             for batch in loader:
@@ -1056,7 +1148,7 @@ class RegenerationTrainer:
         )

         title_append, subtitle = self.build_title(0)
-        head_fig, eye_fig, powers_fig = self.plot_model_response(
+        head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
             model=self.model,
             title_append=title_append,
             subtitle=subtitle,
@@ -1072,6 +1164,11 @@ class RegenerationTrainer:
             eye_fig,
             0,
         )
+        self.writer.add_figure(
+            "weights",
+            weight_fig,
+            0,
+        )
         self.writer.add_figure(
             "powers",
@@ -1103,6 +1200,9 @@ class RegenerationTrainer:

         train_loader, valid_loader = self.get_sliced_data()

+        # train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
+        # train_loader.dataset.fiber_in.to(self.pytorch_settings.device)
+
         optimizer_name = self.optimizer_settings.optimizer

         # lr = self.optimizer_settings.learning_rate
@@ -1132,6 +1232,7 @@ class RegenerationTrainer:
         # except ValueError:
         #     pass

+        self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
         for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
             enable_progress = True
             if enable_progress:
@@ -1147,9 +1248,48 @@ class RegenerationTrainer:
                 epoch,
                 enable_progress=enable_progress,
             )
+            if self.early_stop(loss):
+                self.save_model_checkpoints(epoch, loss)
+                break
             if self.optimizer_settings.scheduler is not None:
                 self.scheduler.step(loss)
                 self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
+            self.save_model_checkpoints(epoch, loss)
+            self.writer.flush()
+
+        save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
+        print(f"Training complete. Best checkpoint: {save_path}")
+        self.writer.close()
+        return self.best
+
+    def early_stop(self, loss):
+        # not stopping early at all
+        if not self.optimizer_settings.early_stopping:
+            return False
+
+        # stopping because of abs threshold
+        if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None:
+            if loss <= loss_thr:
+                print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
+                return True
+
+        # update vals
+        if loss < self.early_stop_vals["min_loss"]:
+            self.early_stop_vals["min_loss"] = loss
+            self.early_stop_vals["plateau_cnt"] = 0
+            return False
+
+        # stopping because of plateau
+        if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None:
+            self.early_stop_vals["plateau_cnt"] += 1
+            if self.early_stop_vals["plateau_cnt"] >= plateau_thresh:
+                print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals['plateau_cnt']} >= {plateau_thresh})")
+                return True
+
+        # no stop
+        return False
+
+    def save_model_checkpoints(self, epoch, loss):
         if self.pytorch_settings.save_models and self.model is not None:
             save_path = (
                 Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
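Stripped of trainer state, the stopping rule added here is: stop once the loss drops below an absolute `threshold`, or once it has failed to improve for `plateau` consecutive epochs (the defaults come from `_early_stop_default_kwargs` in the settings diff above). A self-contained sketch under those assumptions:

```python
def make_early_stop(threshold=1e-05, plateau=25):
    state = {"min_loss": float("inf"), "plateau_cnt": 0}

    def early_stop(loss):
        # absolute-threshold stop
        if threshold is not None and loss <= threshold:
            return True
        # any improvement resets the plateau counter
        if loss < state["min_loss"]:
            state["min_loss"] = loss
            state["plateau_cnt"] = 0
            return False
        # plateau stop: loss has not improved for `plateau` calls in a row
        state["plateau_cnt"] += 1
        return plateau is not None and state["plateau_cnt"] >= plateau

    return early_stop

stop = make_early_stop(threshold=None, plateau=3)
print([stop(l) for l in [1.0, 0.9, 0.9, 0.9, 0.9]])
# [False, False, False, False, True]
```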
@@ -1165,10 +1305,6 @@ class RegenerationTrainer:
             )
             save_path.parent.mkdir(parents=True, exist_ok=True)
             self.save_checkpoint(self.best, save_path)
-        self.writer.flush()
-
-        self.writer.close()
-        return self.best

     def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
         powers = [power / powers[0] for power in powers]
@@ -1190,6 +1326,77 @@ class RegenerationTrainer:
             plt.show()
         return fig

+    def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
+        model_params = []
+        plots = []
+        dims = []
+        for num, (layer_name, layer) in enumerate(model.named_children()):
+            onn_weights = layer.ONN.weight
+            onn_weights = onn_weights.detach().cpu().numpy()
+            onn_values = np.abs(onn_weights).real
+            onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
+
+            model_params.append({layer_name: onn_weights})
+            plots.append({layer_name: (num, onn_values, onn_angles)})
+            dims.append(onn_weights.shape[0])
+
+        max_size = np.max(dims)
+
+        for plot in plots:
+            layer_name, (num, onn_values, onn_angles) = plot.popitem()
+
+            if num == 0:
+                value_img = onn_values
+                angle_img = onn_angles
+                onn_angles = pad_to_size(onn_angles, (max_size, None))
+                onn_values = pad_to_size(onn_values, (max_size, None))
+            else:
+                onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
+                onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
+                value_img = np.concatenate((value_img, onn_values), axis=1)
+                angle_img = np.concatenate((angle_img, onn_angles), axis=1)
+
+        value_img = np.ma.array(value_img, mask=np.isnan(value_img))
+        angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
+
+        fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
+        fig.tight_layout()
+
+        dividers = map(make_axes_locatable, axs)
+        caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
+
+        masked_value_img = value_img
+        cmap = cm.batlow
+        cmap.set_bad(color="#AAAAAA")
+        im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
+        fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
+
+        masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
+        cmap = cm.romaO
+        cmap.set_bad(color="#AAAAAA")
+        im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
+        cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
+        cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
+
+        axs[0].axis("off")
+        axs[1].axis("off")
+
+        axs[0].set_title("Values")
+        axs[1].set_title("Angles")
+
+        title = "Layer Weights"
+        if title_append:
+            title += f" {title_append}"
+        if subtitle:
+            title += f"\n{subtitle}"
+        fig.suptitle(title)
+
+        if show:
+            plt.show()
+        return fig
+
     def _plot_model_response_eye(
         self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
     ):
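The tiling in `_plot_model_weights` pads every layer's matrix to the tallest layer's height and inserts a one-column NaN gap before each subsequent layer, then masks the NaNs so gaps render in the `set_bad` grey. The core of that tiling in isolation (assuming `pad_to_size` from this file; note that in the hunk above the first layer's padded result is discarded, because `value_img` is assigned before the padding, which only works if the first layer is already the tallest):

```python
import numpy as np

def tile_layers(mats):
    """mats: list of 2-D arrays (e.g. |W| per layer), tiled left to right."""
    max_rows = max(m.shape[0] for m in mats)
    img = pad_to_size(mats[0], (max_rows, None))
    for m in mats[1:]:
        # one extra NaN column on the left acts as a separator between layers
        img = np.concatenate((img, pad_to_size(m, (max_rows, m.shape[1] + 1))), axis=1)
    # mask the NaN frame so imshow renders it with the cmap's "bad" colour
    return np.ma.array(img, mask=np.isnan(img))
```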
@@ -1354,7 +1561,7 @@ class RegenerationTrainer:

         data_settings_backup = copy.deepcopy(self.data_settings)
         pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
-        self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
+        self.data_settings.drop_first = int(64 + random.randint(0, 1000))
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
         self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
@@ -1363,7 +1570,7 @@ class RegenerationTrainer:
             if isinstance(self.data_settings.config_path, (list, tuple))
             else self.data_settings.config_path
         )
-        fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
+        # fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
         if not hasattr(self, "_plot_loader"):
             self._plot_loader, _ = self.get_sliced_data(
                 override={
@@ -1376,6 +1583,7 @@ class RegenerationTrainer:
                 }
             )
             self._sps = self._plot_loader.dataset.samples_per_symbol
+        fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
         self.data_settings = data_settings_backup
         self.pytorch_settings = pytorch_settings_backup
@@ -1403,7 +1611,7 @@ class RegenerationTrainer:
         import gc

         head_fig = self._plot_model_response_head(
-            fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps],
+            fiber_out[: self.pytorch_settings.head_symbols * self._sps],
             fiber_in[: self.pytorch_settings.head_symbols * self._sps],
             regen[: self.pytorch_settings.head_symbols * self._sps],
             angles[: self.pytorch_settings.head_symbols * self._sps],
@@ -1417,7 +1625,7 @@ class RegenerationTrainer:
         # raise NotImplementedError("Eye diagram not implemented")
         eye_fig = self._plot_model_response_eye(
             fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
-            fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps],
+            fiber_out[: self.pytorch_settings.eye_symbols * self._sps],
             regen[: self.pytorch_settings.eye_symbols * self._sps],
             timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
             labels=("fiber in", "fiber out", "regen"),
@@ -1426,9 +1634,11 @@ class RegenerationTrainer:
             subtitle=subtitle,
             show=show,
         )
+
+        weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
         gc.collect()

-        return head_fig, eye_fig, power_fig
+        return head_fig, eye_fig, weight_fig, power_fig

     def build_title(self, number: int):
         title_append = f"epoch {number}"
(plot script in src/single-core-regen; file name not shown in this view)
@@ -1,4 +1,6 @@
-import os
+from pathlib import Path
+import sys

 from matplotlib import pyplot as plt
 import numpy as np
 import torch
@@ -26,6 +28,28 @@ from hypertraining import models
 #     constant_values=(-np.inf, -np.inf),
 # )

+def register_puccs_cmap(puccs_path=None):
+    puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
+
+    colors = []
+    # keys = None
+    with open(puccs_path, "r") as f:
+        for i, line in enumerate(f.readlines()):
+            elements = tuple(line.split(","))
+            # if i == 0:
+            #     # keys = elements
+            #     continue
+            # else:
+            try:
+                colors.append(tuple(map(float, elements[4:])))
+            except ValueError:
+                continue
+    # colors = []
+    # for current in puccs_csv_data:
+    #     colors.append(tuple(current[4:]))
+    from matplotlib.colors import LinearSegmentedColormap
+    import matplotlib as mpl
+    mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
+
 def pad_to_size(array, size):
     if not hasattr(size, "__len__"):
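Once registered, the colormap resolves by name through matplotlib's registry (the commented `plt.get_cmap('puccs')` further down uses it that way). A usage sketch, assuming `puccs.csv` sits next to the script as above:

```python
import numpy as np
import matplotlib.pyplot as plt

register_puccs_cmap()              # registers 'puccs' from the CSV's RGB columns
phase = np.random.uniform(0, 2 * np.pi, size=(16, 16))
plt.imshow(phase, cmap="puccs", vmin=0, vmax=2 * np.pi)
plt.colorbar()
plt.show()
```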
@@ -65,7 +89,7 @@ def pad_to_size(array, size):
         constant_values=(np.nan, np.nan),
     )

-def model_plot(model_path):
+def model_plot(model_path, show=True):
     torch.serialization.add_safe_globals([
         *util.complexNN.__all__,
         GlobalSettings,
@@ -81,173 +105,113 @@ def model_plot(model_path):
     dims = checkpoint_dict["model_kwargs"].pop("dims")

     model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
-    model.load_state_dict(checkpoint_dict["model_state_dict"])
+    model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)

     model_params = []
     plots = []
     max_size = np.max(dims)
     # max_act_size = np.max(dims[1:])

-    angles = [None, None]
-    weights = [None, None]
+    # angles = [None, None]
+    # weights = [None, None]

     for num, (layer_name, layer) in enumerate(model.named_children()):
         # each layer contains an "ONN" layer and an "activation" layer
         # activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
-        onn_weights = layer.ONN.weight.T
+        onn_weights = layer.ONN.weight
         onn_weights = onn_weights.detach().cpu().numpy()
         onn_values = np.abs(onn_weights).real
         onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
-
-        act = layer.activation
-        act_values = np.ones((act.size, 1))
-        act_values = np.nan * act_values
-        act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy()
-        ...
-        # act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8)
-        # act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8)
-        # xs = (0.01, 0.1, 1)
-
-        # act_values = np.zeros((act.size, len(xs)*2))
-        # act_angles = np.zeros((act.size, len(xs)*2))
-
-        # act_values[:,:] = np.nan
-        # act_angles[:,:] = np.nan
-
-        # for xi, x in enumerate(xs):
-        #     phi_intermediate = act_phi_gain * x**2 + act_phi_bias
-        #     act_resulting_gain = (
-        #         1j
-        #         * torch.sqrt(1-act.alpha)
-        #         * torch.exp(-0.5j * phi_intermediate)
-        #         * torch.cos(0.5 * phi_intermediate)
-        #         * x
-        #     )
-        #     act_resulting_gain = act_resulting_gain.detach().cpu().numpy()
-        #     act_values[:, xi*2] = np.abs(act_resulting_gain).real
-        #     act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real
-
-        # if angles[0] is None or angles[0] > np.min(onn_angles.flatten()):
-        #     angles[0] = np.min(onn_angles.flatten())
-        # if angles[1] is None or angles[1] < np.max(onn_angles.flatten()):
-        #     angles[1] = np.max(onn_angles.flatten())
-        # if weights[0] is None or weights[0] > np.min(onn_weights.flatten()):
-        #     weights[0] = np.min(onn_weights.flatten())
-        # if weights[1] is None or weights[1] < np.max(onn_weights.flatten()):
-        #     weights[1] = np.max(onn_weights.flatten())

         model_params.append({layer_name: onn_weights})
-        plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)})
+        plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})

     # fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))

     for plot in plots:
-        layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem()
-        # for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0])
-        # for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0])
-
-        onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values))
-        onn_values = onn_values - np.min(onn_values)
-        onn_values = onn_values / np.max(onn_values)
-
-        act_values = np.ma.array(act_values, mask=np.isnan(act_values))
-        act_values = act_values - np.min(act_values)
-        act_values = act_values / np.max(act_values)
-
-        onn_values = onn_values
-        onn_values = pad_to_size(onn_values, (max_size, None))
-
-        act_values = act_values
-        act_values = pad_to_size(act_values, (max_size, 3))
-
-        onn_angles = onn_angles / np.pi
-        onn_angles = pad_to_size(onn_angles, (max_size, None))
-
-        act_angles = act_angles / np.pi
-        act_angles = pad_to_size(act_angles, (max_size, 3))
-
-        # onn_angles = onn_angles - np.min(onn_angles)
-        # onn_angles = onn_angles / np.max(onn_angles)
-
-        # act_angles = act_angles - np.min(act_angles)
-        # act_angles = act_angles / np.max(act_angles)
+        layer_name, (num, onn_values, onn_angles) = plot.popitem()

         if num == 0:
-            value_img = np.concatenate((onn_values, act_values), axis=1)
-            angle_img = np.concatenate((onn_angles, act_angles), axis=1)
+            value_img = onn_values
+            angle_img = onn_angles
+            onn_angles = pad_to_size(onn_angles, (max_size, None))
+            onn_values = pad_to_size(onn_values, (max_size, None))
         else:
-            value_img = np.concatenate((value_img, onn_values, act_values), axis=1)
-            angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1)
+            onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
+            onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
+            value_img = np.concatenate((value_img, onn_values), axis=1)
+            angle_img = np.concatenate((angle_img, onn_angles), axis=1)

-    # -np.inf to np.nan
-    # value_img[value_img == -np.inf] = np.nan
-    # angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size)
-    # angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size)
-
-    from cmcrameri import cm
-    from matplotlib import colors as mcolors
-    alpha_map = mcolors.LinearSegmentedColormap(
-        'alphamap',
-        {
-            'red': [(0, 0, 0), (1, 0, 0)],
-            'green': [(0, 0, 0), (1, 0, 0)],
-            'blue': [(0, 0, 0), (1, 0, 0)],
-            'alpha': [
-                (0, 1, 1),
-                # (0.2, 0.2, 0.1),
-                (1, 0, 0)
-            ]
-        }
-    )
-    alpha_map.set_bad(color="#AAAAAA")
-
-    fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5))
-    fig.tight_layout()
+    value_img = np.ma.array(value_img, mask=np.isnan(value_img))
+    angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
+
+    # from cmcrameri import cm
+    from cmap import Colormap as cm
+    import scicomap as sc
+    # from matplotlib import colors as mcolors
+    # alpha_map = mcolors.LinearSegmentedColormap(
+    #     'alphamap',
+    #     {
+    #         'red': [(0, 0, 0), (1, 0, 0)],
+    #         'green': [(0, 0, 0), (1, 0, 0)],
+    #         'blue': [(0, 0, 0), (1, 0, 0)],
+    #         'alpha': [
+    #             (0, 1, 1),
+    #             # (0.2, 0.2, 0.1),
+    #             (1, 0, 0)
+    #         ]
+    #     }
+    # )
+    # alpha_map.set_bad(color="#AAAAAA")
+
+    from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+    fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
+    # fig.tight_layout()
+    dividers = map(make_axes_locatable, axs)
+    caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))

     # masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
     masked_value_img = value_img
-    cmap = cm.batlowW
+    cmap = cm('google:turbo').to_matplotlib()
+    # cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
     cmap.set_bad(color="#AAAAAA")
-    im_val = axs[0].imshow(masked_value_img, cmap=cmap)
+    im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
+    fig.colorbar(im_val, cax=caxs[0], orientation="vertical")

     masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
-    cmap = cm.romaO
+    # cmap = cm('crameri:romao').to_matplotlib()
+    # cmap = plt.get_cmap('puccs')
+    # cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
+    cmap = cm('colorcet:CET_C8').to_matplotlib()
     cmap.set_bad(color="#AAAAAA")
-    im_ang = axs[1].imshow(masked_angle_img, cmap=cmap)
-    im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
-    im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
+    im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
+    cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
+    cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
+    # im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
+    # im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)

     axs[0].axis("off")
     axs[1].axis("off")
-    axs[2].axis("off")
+    # axs[2].axis("off")

     axs[0].set_title("Values")
     axs[1].set_title("Angles")
-    axs[2].set_title("Values and Angles")
+    # axs[2].set_title("Values and Angles")

     ...
+    if show:
         plt.show()
+    return fig

     # model = models.regenerator(*dims, **model_kwargs)

 if __name__ == "__main__":
-    model_plot(".models/best_20250105_145719.tar")
+    register_puccs_cmap()
+    if len(sys.argv) > 1:
+        model_plot(sys.argv[1])
+    else:
+        print("Please provide a model path as an argument")
+    # model_plot(".models/best_20250114_224234.tar")
src/single-core-regen/puccs.csv (new file, 102 lines)
@@ -0,0 +1,102 @@
+"x","L","a","b","R","G","B"
+0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
+0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
+0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
+0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
+0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
+0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
+0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
+0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
+0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
+0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
+0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
+0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
+0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
+0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
+0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
+0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
+0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
+0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
+0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
+0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
+0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
+0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
+0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
+0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
+0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
+0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
+0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
+0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
+0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
+0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
+0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
+0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
+0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
+0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
+0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
+0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
+0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
+0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
+0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
+0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
+0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
+0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
+0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
+0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
+0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
+0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
+0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
+0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
+0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
+0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
+0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
+0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
+0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
+0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
+0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
+0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
+0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
+0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
+0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
+0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
+0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
+0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
+0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
+0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
+0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
+0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
+0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
+0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
+0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
+0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
+0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
+0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
+0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
+0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
+0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
+0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
+0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
+0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
+0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
+0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
+0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
+0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
+0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
+0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
+0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
+0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
+0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
+0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
+0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
+0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
+0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
+0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
+0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
+0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
+0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
+0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
+0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
+0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
+0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
+0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
+1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
@@ -26,28 +26,39 @@ global_settings = GlobalSettings(
 )
 
 data_settings = DataSettings(
-# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
+# config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
-config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini",
+# config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
-# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
+# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
+# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
+# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
+# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
+# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
+config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
 
+# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
 
 dtype="complex64",
 # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
-symbols=4, # study: single_core_regen_20241123_011232
+symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
 # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
-output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
-shuffle=True,
+shuffle=False,
-drop_first=64,
+drop_first=256,
+drop_last=256,
 train_split=0.8,
 randomise_polarisations=False,
-polarisations=True,
+polarisations=False,
+# cross_pol_interference=0.01,
 osnr=16, #16dB due to amplification with NF 5
 )
 
 pytorch_settings = PytorchSettings(
 epochs=1000,
-batchsize=2**14,
+batchsize=2**13,
 device="cuda",
-dataloader_workers=24,
+dataloader_workers=32,
-dataloader_prefetch=8,
+dataloader_prefetch=4,
 summary_dir=".runs",
 write_every=2**5,
 save_models=True,
@@ -65,16 +76,13 @@ model_settings = ModelSettings(
 # "n_hidden_nodes_3": 4,
 # "n_hidden_nodes_4": 2,
 },
-model_activation_func="phase_shift",
+model_activation_func="EOActivation",
 dropout_prob=0,
 model_layer_function="ONNRect",
 model_layer_kwargs={"square": True},
 scale=2.0,
 model_layer_parametrizations=[
-{
+# EOactivation
-"tensor_name": "weight",
-"parametrization": util.complexNN.energy_conserving,
-},
 {
 "tensor_name": "alpha",
 "parametrization": util.complexNN.clamp,
@@ -83,54 +91,20 @@ model_settings = ModelSettings(
 "max": 1,
 },
 },
+# ONNRect
 {
-"tensor_name": "gain",
+"tensor_name": "weight",
-"parametrization": util.complexNN.clamp,
+"parametrization": torch.nn.utils.parametrizations.orthogonal,
-"kwargs": {
-"min": 0,
-"max": None,
-},
-},
-{
-"tensor_name": "phase_bias",
-"parametrization": util.complexNN.clamp,
-"kwargs": {
-"min": 0,
-"max": 2 * torch.pi,
-},
 },
+# Scale
 {
 "tensor_name": "scale",
 "parametrization": util.complexNN.clamp,
 "kwargs": {
 "min": 0,
-"max": 2,
+"max": 10,
-},
-},
-{
-"tensor_name": "angle",
-"parametrization": util.complexNN.clamp,
-"kwargs": {
-"min": -torch.pi,
-"max": torch.pi,
-},
-},
-# {
-# "tensor_name": "scale",
-# "parametrization": util.complexNN.clamp,
-# },
-# {
-# "tensor_name": "bias",
-# "parametrization": util.complexNN.clamp,
-# },
-# {
-# "tensor_name": "V",
-# "parametrization": torch.nn.utils.parametrizations.orthogonal,
-# },
-{
-"tensor_name": "loss",
-"parametrization": util.complexNN.clamp,
-},
 },
+}
 ],
 )
 
@@ -145,107 +119,19 @@ optimizer_settings = OptimizerSettings(
 scheduler="ReduceLROnPlateau",
 scheduler_kwargs={
 "patience": 2**6,
-"factor": 0.75,
+"factor": 0.5,
 # "threshold": 1e-3,
 "min_lr": 1e-6,
 "cooldown": 10,
 },
+early_stopping=True,
+early_stop_kwargs={
+"threshold": 1e-06,
+"plateau": 2**7,
+}
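Note: the early_stopping kwargs added above are not defined elsewhere in this diff. A minimal sketch of plateau-based early stopping consistent with these names, assuming "threshold" is the minimum loss improvement and "plateau" the number of epochs tolerated without one (the actual RegenerationTrainer implementation is not shown here):

```python
# Hypothetical helper illustrating the assumed semantics of the kwargs above.
class EarlyStopping:
    def __init__(self, threshold=1e-6, plateau=2**7):
        self.threshold = threshold      # minimum improvement that resets the counter
        self.plateau = plateau          # epochs without improvement before stopping
        self.best = float("inf")
        self.stale_epochs = 0

    def step(self, loss: float) -> bool:
        """Return True when training should stop."""
        if loss < self.best - self.threshold:
            self.best = loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.plateau
```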
 )
 
 
-def save_dict_to_file(dictionary, filename):
-"""
-Save the best dictionary to a JSON file.
-
-:param best: Dictionary containing the best training results.
-:type best: dict
-:param filename: Path to the JSON file where the dictionary will be saved.
-:type filename: str
-"""
-with open(filename, "w") as f:
-json.dump(dictionary, f, indent=4)
-
-
-def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
-assert model is not None, "Model must be provided."
-assert data_glob is not None, "Data glob must be provided."
-model = model
-
-fiber_ins = {}
-fiber_outs = {}
-regens = {}
-timestampss = {}
-
-trainer = RegenerationTrainer(
-checkpoint_path=model,
-)
-trainer.define_model()
-
-for length in lengths:
-data_glob_length = data_glob.replace("{length}", str(length))
-files = list(Path.cwd().glob(data_glob_length))
-if len(files) == 0:
-continue
-if strategy == "newest":
-sorted_kwargs = {
-"key": lambda x: x.stat().st_mtime,
-"reverse": True,
-}
-elif strategy == "oldest":
-sorted_kwargs = {
-"key": lambda x: x.stat().st_mtime,
-"reverse": False,
-}
-else:
-raise ValueError(f"Unknown strategy {strategy}.")
-file = sorted(files, **sorted_kwargs)[0]
-
-loader, _ = trainer.get_sliced_data(override={"config_path": file})
-fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
-
-fiber_ins[length] = fiber_in
-fiber_outs[length] = fiber_out
-regens[length] = regen
-timestampss[length] = timestamps
-
-data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
-channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
-
-data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
-data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
-
-channel_names[1] = "fiber in x"
-
-for li, length in enumerate(timestampss.keys()):
-data[2 + 2 * li, 0, :] = timestampss[length] / 128
-data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
-data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
-data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
-
-channel_names[2 + 2 * li + 1] = f"regen x {length}"
-channel_names[2 + 2 * li] = f"fiber out x {length}"
-
-# get current backend
-backend = matplotlib.get_backend()
-
-matplotlib.use("TkCairo")
-eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
-
-print_attrs = ("channel_name", "success", "min_area")
-with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
-for result in eye.eye_stats:
-print_dict = {attr: result[attr] for attr in print_attrs}
-rprint(print_dict)
-rprint()
-
-eye.plot(all_stats=False)
-matplotlib.use(backend)
-
-
 if __name__ == "__main__":
-# lengths = range(90000, 100000+10000, 10000)
-# lengths = [100000]
-# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
-
 trainer = RegenerationTrainer(
 global_settings=global_settings,
@@ -253,83 +139,15 @@ if __name__ == "__main__":
 pytorch_settings=pytorch_settings,
 model_settings=model_settings,
 optimizer_settings=optimizer_settings,
-# checkpoint_path=".models/best_20250104_191428.tar",
+checkpoint_path=".models/best_20250117_144001.tar",
-reset_epoch=True,
+new_model=True,
-# settings_override={
+settings_override={
-# "data_settings": {
+"data_settings": data_settings.__dict__,
-# "config_path": "data/20241229-163*-128-16384-100000-*.ini",
-# "polarisations": True,
-# },
-# "model_settings": {
-# "scale": 2.0,
-# }
-# }
 # "optimizer_settings": {
-# "optimizer_kwargs": {
+# "early_stop_kwargs":{
-# "lr": 0.01,
+# "plateau": 2**8,
-# },
 # }
 # }
-# 20241202_143149
+}
 )
 trainer.train()
 
-# from hypertraining.lighning_models import regenerator, regeneratorData
-# import lightning as L
-
-# model = regenerator(
-# 2 * data_settings.output_size,
-# *model_settings.overrides["hidden_layer_dims"],
-# model_settings.output_dim,
-# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
-# layer_func_kwargs=model_settings.model_layer_kwargs,
-# act_function=getattr(util.complexNN, model_settings.model_activation_func),
-# act_func_kwargs=None,
-# parametrizations=model_settings.model_layer_parametrizations,
-# dtype=getattr(torch, data_settings.dtype),
-# dropout_prob=model_settings.dropout_prob,
-# scale_layers=model_settings.scale,
-# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
-# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
-# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
-# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
-# )
-
-# dm = regeneratorData(
-# config_globs=data_settings.config_path,
-# output_symbols=data_settings.symbols,
-# output_dim=data_settings.output_size,
-# dtype=getattr(torch, data_settings.dtype),
-# drop_first=data_settings.drop_first,
-# shuffle=data_settings.shuffle,
-# train_split=data_settings.train_split,
-# batch_size=pytorch_settings.batchsize,
-# loader_settings={
-# "num_workers": pytorch_settings.dataloader_workers,
-# "prefetch_factor": pytorch_settings.dataloader_prefetch,
-# "pin_memory": True,
-# "drop_last": True,
-# },
-# seed=global_settings.seed,
-# )
-
-# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
-
-# # from torch.utils.tensorboard import SummaryWriter
-
-# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-
-# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
-
-# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
-
-# trainer = L.Trainer(
-# fast_dev_run=False,
-# # max_epochs=pytorch_settings.epochs,
-# max_epochs=2,
-# enable_checkpointing=True,
-# default_root_dir=f".models/{subdir}/",
-# logger=logger,
-# )
-
-# trainer.fit(model, dm)
@@ -12,6 +12,7 @@ Full license text in LICENSE file
 """
 
 import configparser
+# import copy
 from datetime import datetime
 import hashlib
 from pathlib import Path
@@ -40,7 +41,7 @@ alpha = 0.2
 D = 17
 S = 0.058
 bireflength = 10
-max_delta_beta = 0.14
+pmd_q = 0.2
 ; birefseed = 0xC0FFEE
 
 [signal]
@@ -195,10 +196,14 @@ class pam_generator:
 
 
 def initialize_fiber_and_data(config):
+f0 = config["glova"].get("f0", None)
+if f0 is None:
+f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
+config["glova"]["f0"] = f0
 py_glova = pypho.setup(
 nos=config["glova"]["nos"],
 sps=config["glova"]["sps"],
-f0=config["glova"]["f0"],
+f0=f0,
 symbolrate=config["glova"]["symbolrate"],
 wisdom_dir=config["glova"]["wisdom_dir"],
 flags=config["glova"]["flags"],
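Note: the fallback added above derives the carrier frequency from the wavelength via f0 = c / lambda. A quick sanity check of the numbers (values assumed from the 1550 nm default, not from a specific config file):

```python
# f0 = c / lambda0: the 1550 nm default gives roughly 193.4 THz.
c = 299792458          # speed of light in m/s
lambda0 = 1550e-9      # default wavelength in m
f0 = c / lambda0
print(f"{f0 / 1e12:.1f} THz")  # -> 193.4 THz
```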
@@ -216,7 +221,9 @@ def initialize_fiber_and_data(config):
 symbolsrc = pypho.symbols(
 py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
 )
-laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4)
+laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
+# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
+
 modulator = pam_generator(
 py_glova,
 mod_depth=config["signal"]["mod_depth"],
@@ -232,7 +239,12 @@ def initialize_fiber_and_data(config):
 symbols_y[:3] = 0
 # symbols_x += 1
 
-cw = laser()
+cw = laserx()
+# cwy = lasery()
+# cw[0]['E'][0] = cw[0]['E'][0]
+# cw[0]['E'][1] = cwy[0]['E'][0]
+
 
 source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
 
@@ -251,13 +263,41 @@ def initialize_fiber_and_data(config):
 
 # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
 
+## side channels
+# df = 100
+# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
+
+
+# symbols_x_side = symbolsrc(pattern="random")
+# symbols_y_side = symbolsrc(pattern="random")
+# symbols_x_side[:3] = 0
+# symbols_y_side[:3] = 0
+
+# cw_left = laser(Df=-df)
+# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
+
+# cw_right = laser(Df=df)
+# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
+
+E_in_pure = source_signal[0]["E"]
+
 nf = py_edfa.NF
-source_signal = py_edfa(E=source_signal, NF=0)
-py_edfa.NF = nf
+pmean = py_edfa.Pmean
+# ideal amplification to launch power into fiber
+source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
+# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
+# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
+
+# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
+# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
+
 c_data.E_in = source_signal[0]["E"]
 noise = source_signal[0]["noise"]
 
+py_edfa.NF = nf
+py_edfa.Pmean = pmean
+
 py_fiber = pypho.fiber(
 glova=py_glova,
 l=config["fiber"]["length"],
@@ -265,20 +305,29 @@ def initialize_fiber_and_data(config):
 gamma=config["fiber"]["gamma"],
 D=config["fiber"]["d"],
 S=config["fiber"]["s"],
+phi_max=0.02,
 )
-if config["fiber"].get("birefsteps", 0) > 0:
+config["fiber"]["birefsteps"] = config["fiber"].get(
+"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
+)
+if config["fiber"]["birefsteps"] > 0:
+config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
 seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
 py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
-py_fiber.l,
+config["fiber"]["length"],
-py_fiber.l / config["fiber"]["birefsteps"],
+config["fiber"]["bireflength"],
-# maxDeltaD=config["fiber"]["d"]/5,
-maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
+maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
 seed=seed,
 )
-c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200)
+elif (dgd := config['fiber'].get('dgd', 0)) > 0:
+py_fiber.birefarray = [
+pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
+]
+c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
 c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
 
-return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y)
+return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
 
 
 def save_data(data, config, **metadata):
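Note: two birefringence modes are configured above. The statistical PMD fibre sets a per-segment maxDeltaBeta of pmd_q / sqrt(2 * bireflength), while the deterministic worst case (matching case d in the notes) uses a single segment at angle pi/2 whose delta_beta is derived from the requested DGD. A sketch of the bookkeeping, assuming dgd is given in ps and lengths in m as the config names suggest (the factor 1000 is taken from the code above, its unit convention is not documented in this diff):

```python
import numpy as np

length = 100_000        # fiber length in m, as in the 100000 configs
bireflength = 10        # birefringent segment length in m
pmd_q = 0.2             # PMD coefficient as stored in the config

# per-segment birefringence for the statistical PMD model
max_delta_beta = pmd_q / np.sqrt(2 * bireflength)   # ~0.045

# deterministic worst case: one segment whose accumulated
# delta_beta * length reproduces the requested DGD
dgd = 10                # target DGD for case d)
delta_beta = 1000 * dgd / length                    # 0.1
```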
@@ -316,8 +365,11 @@ def save_data(data, config, **metadata):
 f"D = {config['fiber']['d']}",
 f"S = {config['fiber']['s']}",
 f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
-f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
+f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
 f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
+f"dgd = {config['fiber'].get('dgd', 0)}",
+f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
+f"pol_error = {config['fiber'].get('pol_error', 0)}",
 "",
 "[signal]",
 f"seed = {hex(seed)}" if seed else "; seed = not set",
@@ -346,24 +398,12 @@ def save_data(data, config, **metadata):
 save_file = f"{config_hash}.h5"
 config_content += f'"{str(save_file)}"\n'
 
-filename_components = (
-timestamp.strftime("%Y%m%d-%H%M%S"),
-config["glova"]["sps"],
-config["glova"]["nos"],
-config["signal"]["osnr"],
-config["fiber"]["length"],
-config["fiber"]["gamma"],
-config["fiber"]["alpha"],
-config["fiber"]["d"],
-config["fiber"]["s"],
-f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
-config["fiber"].get("birefsteps", 0),
-config["fiber"].get("max_delta_beta", 0),
-int(config["glova"]["symbolrate"] / 1e9),
-)
-
-lookup_file = "-".join(map(str, filename_components)) + ".ini"
-config_filename = data_dir / lookup_file
+config_filename: Path = create_config_filename(config, data_dir, timestamp)
+while config_filename.exists():
+time.sleep(1)
+config_filename = create_config_filename(config, data_dir=data_dir)
 with open(config_filename, "w") as f:
 f.write(config_content)
 
@@ -376,11 +416,31 @@ def save_data(data, config, **metadata):
 outfile.attrs[key] = value
 # np.save(save_dir / save_file, save_data)
 
-print("Saved config to", config_filename)
+# print("Saved config to", config_filename)
-print("Saved data to", save_dir / save_file)
+# print("Saved data to", save_dir / save_file)
 
 return config_filename
 
+def create_config_filename(config, data_dir: Path, timestamp=None):
+if timestamp is None:
+timestamp = datetime.now()
+filename_components = (
+timestamp.strftime("%Y%m%d-%H%M%S"),
+config["glova"]["sps"],
+config["glova"]["nos"],
+config["signal"]["osnr"],
+config["fiber"]["length"],
+config["fiber"]["gamma"],
+config["fiber"]["alpha"],
+config["fiber"]["d"],
+config["fiber"]["s"],
+f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
+config["fiber"].get("birefsteps", 0),
+config["fiber"].get("pmd_q", 0),
+int(config["glova"]["symbolrate"] / 1e9),
+)
+lookup_file = "-".join(map(str, filename_components)) + ".ini"
+return data_dir / lookup_file
+
 def length_loop(config, lengths, save=True):
 lengths = sorted(lengths)
@@ -388,7 +448,7 @@ def length_loop(config, lengths, save=True):
 print(f"\nGenerating data for fiber length {length}m")
 config["fiber"]["length"] = length
 
-cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
+cfiber, cdata, noise, edfa, symbols, py_glova, _ = initialize_fiber_and_data(config)
 
 mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
 cfiber()
@@ -416,51 +476,49 @@ def single_run_with_plot(config, save=True):
 in_out_eyes(cfiber, cdata, show_pols=False)
 return config_filename
 
-def single_run(config, save=True):
-cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)
-
-# mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
-# print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
-
-# estimate osnr
-# noise_power = np.mean(noise)
-# osnr_lin = mean_power_in / noise_power - 1
-# osnr = 10 * np.log10(osnr_lin)
-# print(f"Estimated OSNR: {osnr:.3f} dB")
-
+def single_run(config, save=True, silent=True):
+cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
+
+# transmit
 cfiber()
 
-# mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
-# print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
-
-# noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)
-
-# estimate osnr
-# noise_power = np.mean(noise)
-# osnr_lin = mean_power_out / noise_power - 1
-# osnr = 10 * np.log10(osnr_lin)
-# print(f"Estimated OSNR: {osnr:.3f} dB")
-
+# amplify
 E_tmp = [{"E": cdata.E_out, "noise": noise}]
 
 E_tmp = edfa(E=E_tmp)
 
+# rotate
+# ortho error
+ortho_error = config["fiber"].get("ortho_error", 0)
+
+E_tmp[0]["E"] = np.stack((
+E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
+E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
+), axis=0)
+
+pol_error = config['fiber'].get('pol_error', 0)
+
+E_tmp[0]["E"] = np.stack((
+E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
+E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
+), axis=0)
+
+# output
 cdata.E_out = E_tmp[0]["E"]
-# noise = E_tmp[0]["noise"]
-
-# mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
-# print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
-
-# estimate osnr
-# noise_power = np.mean(noise)
-# osnr_lin = mean_power_amp / noise_power - 1
-# osnr = 10 * np.log10(osnr_lin)
-# print(f"Estimated OSNR: {osnr:.3f} dB")
-
 config_filename = None
 symbols = np.array(symbols)
 if save:
 config_filename = save_data(cdata, config, **{"symbols": symbols})
-return cfiber,cdata,config_filename
+if not silent:
+print(f"Saved config to {config_filename}")
+return cfiber, cdata, config_filename
 
 
 def in_out_eyes(cfiber, cdata, show_pols=False):
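Note: single_run now applies two 2x2 Jones-space operations after the EDFA. pol_error is a proper rotation (orthogonality-preserving, -sin/+sin off-diagonal), while ortho_error deliberately uses +sin in both rows, so the two polarisations are no longer orthogonal after it. A standalone sketch of the same operations (the field values here are made up for illustration):

```python
import numpy as np

# two polarisation rows, a few samples each (hypothetical values)
E = np.array([[1.0 + 0j, 0.2 + 0j],
              [0.0 + 0j, 1.0 + 0j]])

pol_error, ortho_error = 0.4, 0.1

# proper rotation: det = cos^2 + sin^2 = 1, orthogonality preserved
R = np.array([[np.cos(pol_error), -np.sin(pol_error)],
              [np.sin(pol_error),  np.cos(pol_error)]])

# orthogonality error: +sin in both rows mixes the polarisations
# without compensation, exactly as in single_run above
O = np.array([[np.cos(ortho_error / 2), np.sin(ortho_error / 2)],
              [np.sin(ortho_error / 2), np.cos(ortho_error / 2)]])

E_out = R @ (O @ E)
```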
138	src/single-core-regen/testing/prob_dens.ipynb	Normal file
File diff suppressed because one or more lines are too long
@@ -441,8 +441,7 @@ class input_rotator(nn.Module):
 # return out
 
 
 #### as defined by zhang et al.
-
 
 class DropoutComplex(nn.Module):
 def __init__(self, p=0.5):
@@ -464,7 +463,7 @@ class Scale(nn.Module):
 self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
 
 def forward(self, x):
-return x * self.scale
+return x * torch.sqrt(self.scale)
 
 def __repr__(self):
 return f"Scale({self.size})"
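Note: with this change (and the new out_scale stage in the model), Scale multiplies the field by sqrt(scale) instead of scale, so the learned parameter acts on power rather than amplitude, since |x * sqrt(s)|^2 = s * |x|^2. A quick check:

```python
import torch

x = torch.tensor([1.0 + 1.0j])   # |x|^2 = 2
s = torch.tensor([4.0])          # learned power scale

y = x * torch.sqrt(s)
# output power is the input power scaled by s, not by s^2
assert torch.allclose(y.abs().square(), s * x.abs().square())
```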
@@ -546,35 +545,31 @@ class EOActivation(nn.Module):
 raise ValueError("Size must be specified")
 self.size = size
 self.alpha = nn.Parameter(torch.rand(size))
-self.V_bias = nn.Parameter(torch.rand(size))
 self.gain = nn.Parameter(torch.rand(size))
-# if bias:
+self.V_bias = nn.Parameter(torch.rand(size))
-# self.phase_bias = nn.Parameter(torch.zeros(size))
+# self.register_buffer("gain", torch.ones(size))
-# else:
+# self.register_buffer("responsivity", torch.ones(size))
-# self.register_buffer("phase_bias", torch.zeros(size))
+# self.register_buffer("V_pi", torch.ones(size))
-# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
-self.register_buffer("responsivity", torch.ones(size)*0.9)
-self.register_buffer("V_pi", torch.ones(size)*3)
 
 self.reset_weights()
 
 def reset_weights(self):
 if "alpha" in self._parameters:
 self.alpha.data = torch.rand(self.size)
-if "V_pi" in self._parameters:
+# if "V_pi" in self._parameters:
-self.V_pi.data = torch.rand(self.size)*3
+# self.V_pi.data = torch.rand(self.size)*3
 if "V_bias" in self._parameters:
 self.V_bias.data = torch.randn(self.size)
 if "gain" in self._parameters:
 self.gain.data = torch.rand(self.size)
-if "responsivity" in self._parameters:
+# if "responsivity" in self._parameters:
-self.responsivity.data = torch.ones(self.size)*0.9
+# self.responsivity.data = torch.ones(self.size)*0.9
 # if "bias" in self._parameters:
 # self.phase_bias.data = torch.zeros(self.size)
 
 def forward(self, x: torch.Tensor):
-phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
+phi_b = torch.pi * self.V_bias  # / (self.V_pi)
-g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
+g_phi = torch.pi * (self.alpha * self.gain)  # * self.responsivity)  # / (self.V_pi)
 intermediate = g_phi * x.abs().square() + phi_b
 return (
 1j
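Note: the forward pass in the diff is truncated after the leading 1j, and the change drops responsivity and V_pi from the bias and gain phases. For orientation, a sketch of a full electro-optic activation in the form used by Williamson et al.; the continuation of the truncated return expression is an assumption here, not something shown in the diff:

```python
import torch

def eo_activation(x, alpha, gain, V_bias):
    # phase terms as in the new forward(): responsivity and V_pi dropped
    phi_b = torch.pi * V_bias
    g_phi = torch.pi * (alpha * gain)
    phi = g_phi * x.abs().square() + phi_b
    # assumed continuation of the truncated "return (1j ..." expression,
    # following the usual EO activation
    # f(z) = 1j * sqrt(1 - a) * exp(-1j * phi / 2) * cos(phi / 2) * z
    return 1j * torch.sqrt(1 - alpha) * torch.exp(-1j * phi / 2) * torch.cos(phi / 2) * x
```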
105	src/single-core-regen/util/core.py	Normal file
@@ -0,0 +1,105 @@
+# Copyright (c) 2015, Warren Weckesser. All rights reserved.
+# This software is licensed according to the "BSD 2-clause" license.
+
+import hashlib
+import h5py
+import numpy as _np
+from scipy.interpolate import interp1d as _interp1d
+from scipy.ndimage import gaussian_filter as _gaussian_filter
+from ._brescount import bres_curve_count as _bres_curve_count
+from pathlib import Path
+
+
+__all__ = ['grid_count']
+
+
+def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
+"""
+Parameters
+----------
+`y` is the 1-d array of signal samples.
+
+`window_size` is the number of samples to show horizontally in the
+eye diagram. Typically this is twice the number of samples in a
+"symbol" (i.e. in a data bit).
+
+`offset` is the number of initial samples to skip before computing
+the eye diagram. This allows the overall phase of the diagram to
+be adjusted.
+
+`size` must be a tuple of two integers. It sets the size of the
+array of counts, (height, width). The default is (800, 640).
+
+`fuzz`: If True, the values in `y` are reinterpolated with a
+random "fuzz factor" before plotting in the eye diagram. This
+reduces an aliasing-like effect that arises with the use of
+Bresenham's algorithm.
+
+`bounds` must be a tuple of two floating point values, (ymin, ymax).
+These set the y range of the returned array. If not given, the
+bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
+`y.max() - y.min()`.
+
+Return Value
+------------
+Returns a numpy array of integers.
+"""
+# hash input params
+param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
+param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
+cache_dir = Path.home()/".eyediagram"/".cache"
+cache_dir.mkdir(parents=True, exist_ok=True)
+if (cache_dir/param_hash).is_file():
+try:
+with h5py.File(cache_dir/param_hash, "r") as infile:
+counts = infile["counts"][:]
+if counts.size != 0:
+return counts
+except (OSError, KeyError):
+pass
+
+if size is None:
+size = (800, 640)
+height, width = size
+dt = width / window_size
+counts = _np.zeros((width, height), dtype=_np.int32)
+
+if bounds is None:
+ymin = y.min()
+ymax = y.max()
+yamp = ymax - ymin
+ymin = ymin - 0.05*yamp
+ymax = ymax + 0.05*yamp
+ymax = _np.ceil(ymax*10)/10
+ymin = _np.floor(ymin*10)/10
+else:
+ymin, ymax = bounds
+
+start = offset
+while start + window_size < len(y):
+end = start + window_size
+yy = y[start:end+1]
+k = _np.arange(len(yy))
+xx = dt*k
+if fuzz:
+f = _interp1d(xx, yy, kind='cubic')
+jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
+xx[1:-1] += jiggle
+yd = f(xx)
+else:
+yd = yy
+iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
+_bres_curve_count(xx.astype(_np.int32), iyd, counts)
+
+start = end
+
+if blur != 0:
+counts = _gaussian_filter(counts, sigma=blur)
+
+with h5py.File(cache_dir/param_hash, "w") as outfile:
+# store under the same dataset key the cache read above expects
+outfile.create_dataset("counts", data=counts)
+
+return counts
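Note: a usage sketch for the new grid_count; it returns a (width, height) int32 histogram of Bresenham-rasterised traces and caches results under ~/.eyediagram/.cache keyed by a parameter hash. The signal and import path below are assumptions for illustration:

```python
import numpy as np
from util.core import grid_count  # import path assumed

sps = 128                                   # samples per symbol (assumed)
t = np.arange(64 * sps)
# crude two-level test signal with a little noise (made up)
y = 0.5 + 0.4 * np.sign(np.sin(2 * np.pi * t / sps)) + 0.02 * np.random.randn(t.size)

counts = grid_count(y, window_size=2 * sps, size=(800, 640), fuzz=True, blur=1)
print(counts.shape, counts.dtype)           # (640, 800) int32, i.e. (width, height)
```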
@@ -25,13 +25,14 @@ import multiprocessing as mp
 # def __len__(self):
 # return len(self.indices)
 
+
 def load_from_file(datapath):
-if str(datapath).endswith('.h5'):
+if str(datapath).endswith(".h5"):
 symbols = None
 with h5py.File(datapath, "r") as infile:
 data = infile["data"][:]
 try:
-symbols = infile["symbols"][:]
+symbols = np.swapaxes(infile["symbols"][:], 0, 1)
 except KeyError:
 pass
 else:
@@ -40,7 +41,7 @@ def load_from_file(datapath):
 return data, symbols
 
 
-def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
+def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
 filepath = Path(config_path)
 filepath = filepath.parent.glob(filepath.name)
 config = configparser.ConfigParser()
@@ -58,12 +59,20 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, d
 
 data, orig_symbols = load_from_file(datapath)
 
-data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
+data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
-orig_symbols = orig_symbols[skipfirst:symbols+skipfirst]
+orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
-timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
+timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
 
 data *= np.sqrt(normalize)
 
+launch_power = float(config["signal"]["laser_power"])
+output_power = float(config["signal"]["edfa_power"])
+
+target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
+# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
+
+data[:, 0:2] *= np.sqrt(target_normalization)
+
 # if normalize:
 # # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
 # a, b, c, d = data.T
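Note: the added rescaling converts the launch-to-output power ratio from dB to a linear factor: a gain of G dB multiplies power by 10^(G/10), and therefore amplitude by its square root. Worked example with assumed powers:

```python
launch_power = 0.0   # dBm at the transmitter (assumed)
output_power = 3.0   # dBm after the EDFA (assumed)

target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
print(target_normalization)   # ~2.0: +3 dB doubles the power
# the field columns are scaled by sqrt(target_normalization), i.e. ~1.41
```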
@@ -132,13 +141,15 @@ class FiberRegenerationDataset(Dataset):
 target_delay: float | int = 0,
 xy_delay: float | int = 0,
 drop_first: float | int = 0,
+drop_last=0,
 dtype: torch.dtype = None,
 real: bool = False,
 device=None,
 # osnr: float|None = None,
-polarisations = None,
+polarisations=None,
 randomise_polarisations: bool = False,
 repeat_randoms: int = 1,
+# cross_pol_interference: float = 0,
 **kwargs,
 ):
 """
@@ -172,6 +183,7 @@ class FiberRegenerationDataset(Dataset):
 assert drop_first >= 0, "drop_first must be non-negative"
 
 self.randomise_polarisations = randomise_polarisations
+# self.cross_pol_interference = cross_pol_interference
 
 data_raw = None
 self.config = None
@@ -181,6 +193,7 @@ class FiberRegenerationDataset(Dataset):
 data, config, orig_syms = load_data(
 file_path,
 skipfirst=drop_first,
+skiplast=drop_last,
 symbols=kwargs.get("num_symbols", None),
 real=real,
 normalize=1000,
@@ -300,20 +313,18 @@ class FiberRegenerationDataset(Dataset):
 # fiber_out: [E_out_x, E_out_y, timestamps]
 
 # add noise related to amplification necessary due to splitting of the signal
-gain_lin = output_dim*2
+# gain_lin = output_dim*2
-edfa_nf = float(self.config["signal"]["edfa_nf"])
+# gain_lin = 1
-nf_lin = 10**(edfa_nf/10)
+# edfa_nf = float(self.config["signal"]["edfa_nf"])
-f0 = float(self.config["glova"]["f0"])
+# nf_lin = 10**(edfa_nf/10)
+# f0 = float(self.config["glova"]["f0"])
-noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
-
-noise = torch.randn_like(fiber_out[:2, :])
-noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
-noise = noise * torch.sqrt(noise_add / noise_power)
-fiber_out[:2, :] += noise
 
+# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
+
+# noise = torch.randn_like(fiber_out[:2, :])
+# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
+# noise = noise * torch.sqrt(noise_add / noise_power)
+# fiber_out[:2, :] += noise
+
 # if osnr is None:
 # noisy = fiber_out[:2, :]
@@ -324,7 +335,6 @@ class FiberRegenerationDataset(Dataset):
 
 # fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
-
 
 if repeat_randoms > 1:
 fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
 fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
@@ -334,8 +344,9 @@ class FiberRegenerationDataset(Dataset):
 
 if self.randomise_polarisations:
 angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
-# start_angle = torch.rand(1) * 2 * torch.pi
+start_angle = torch.rand(1) * 2 * torch.pi
-# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
+angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
+angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
 # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
 else:
 angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
@@ -361,8 +372,6 @@ class FiberRegenerationDataset(Dataset):
 # 4 E_out_y_rot,
 # 5 angle
-
-
 
 # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
 # data layout
 # [ [E_in_x, E_in_y, timestamps],
@@ -374,6 +383,9 @@ class FiberRegenerationDataset(Dataset):
 self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
 self.fiber_out = self.fiber_out.movedim(-2, 0)
 
+# if self.randomise_polarisations:
+# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
+
 # self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
 # self.data = self.data.movedim(-2, 0)
 # self.angles = torch.zeros(self.data.shape[0])
@@ -392,12 +404,12 @@ class FiberRegenerationDataset(Dataset):
 return self.fiber_in.shape[0]
 
 def add_noise(self, data, osnr):
-osnr_lin = 10**(osnr/10)
+osnr_lin = 10 ** (osnr / 10)
 popt = torch.mean(data.abs().square().squeeze(), dim=-1)
 noise = torch.randn_like(data)
 pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
 
-mult = torch.sqrt(popt/(pn*osnr_lin))
+mult = torch.sqrt(popt / (pn * osnr_lin))
 mult = mult * torch.eye(popt.shape[0], device=mult.device)
 mult = mult.to(dtype=noise.dtype)
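Note: add_noise scales white Gaussian noise so that the ratio of mean signal power to mean noise power equals the requested OSNR. A compact check of the scaling rule, with the tensor shapes simplified to a single channel (values assumed):

```python
import torch

osnr_db = 16.0
osnr_lin = 10 ** (osnr_db / 10)

data = torch.randn(1, 4096, dtype=torch.complex64)
noise = torch.randn_like(data)

popt = data.abs().square().mean(dim=-1)      # signal power
pn = noise.abs().square().mean(dim=-1)       # raw noise power
noise = noise * torch.sqrt(popt / (pn * osnr_lin))

measured = 10 * torch.log10(popt / noise.abs().square().mean(dim=-1))
# measured is ~16 dB for this noise draw
```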
@@ -406,7 +418,6 @@ class FiberRegenerationDataset(Dataset):
 noisy = data + noise
 return noisy
-
 
 def __getitem__(self, idx):
 if isinstance(idx, slice):
 return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
@@ -418,6 +429,10 @@ class FiberRegenerationDataset(Dataset):
 output_dim = self.output_dim // 2
 self.output_dim = output_dim * 2
 
+if not self.polarisations:
+output_dim = 2 * output_dim
+
+
 fiber_in = self.fiber_in[idx].squeeze()
 fiber_out = self.fiber_out[idx].squeeze()
 
@@ -427,85 +442,35 @@ class FiberRegenerationDataset(Dataset):
 fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
 fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
 
+center_angle = fiber_out[5, output_dim // 2, 0]
-# data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
-
-# data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
-
-# angle = self.angles[idx]
-
-# fiber_in:
-# 0 E_in_x,
-# 1 E_in_y,
-# 2 timestamps
-
-# fiber_out:
-# 0 E_out_x,
-# 1 E_out_y,
-# 2 timestamps,
-# 3 E_out_x_rot,
-# 4 E_out_y_rot,
-# 5 angle
-
-center_angle = fiber_out[0, output_dim // 2, 0]
 angles = fiber_out[5, :, 0]
-plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
 plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
 data = fiber_out[0:2, :, 0]
-# fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone()
+plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
 
-# if self.polarisations:
-# rot = int(np.random.randint(2)*2-1)
-# pol_flipped_data[0:1, :] = rot*data[0, :]
-# pol_flipped_data[1, :] = rot*data[1, :]
-# plot_data_rot[0] = rot*plot_data_rot[0]
-# plot_data_rot[1] = rot*plot_data_rot[1]
-# center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
-# angles = angles + (rot - 1) * torch.pi/2
-
-
-# if self.randomise_polarisations:
-# data = data.mT
-# c = torch.cos(angle).unsqueeze(-1)
-# s = torch.sin(angle).unsqueeze(-1)
-# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
-# data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
-...
-
-# angle = torch.zeros_like(angle)
-
-# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
-# angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim)
-# angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim)
-# sop = self.polarimeter(plot_data)
-# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
-# angle = data_slice[1, 3, self.output_dim // 2, 0].real
 target = fiber_in[:2, output_dim // 2, 0]
 plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
 target_timestamp = fiber_in[2, output_dim // 2, 0].real
 ...
 
 if self.polarisations:
-rot = int(np.random.randint(2)*2-1)
+rot = int(np.random.randint(2) * 2 - 1)
-data = rot*data
+data = rot * data
-target = rot*target
+target = rot * target
-plot_data_rot = rot*plot_data_rot
+plot_data_rot = rot * plot_data_rot
-center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
+center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
-angles = angles + (rot - 1) * torch.pi/2
+angles = angles + (rot - 1) * torch.pi / 2
 
 pol_flipped_data = -data
 pol_flipped_target = -target
 
-# data_timestamps = data[-1,:].real
-# data = data[:-1, :]
-# target_timestamp = target[-1].real
-# target = target[:-1]
-# plot_data = plot_data[:-1]
-
 # transpose to interleave the x and y data in the output tensor
 data = data.transpose(0, 1).flatten().squeeze()
+data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
 pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
+pol_flipped_data = pol_flipped_data / torch.sqrt(
+torch.ones(1) * len(pol_flipped_data)
+) # power loss due to splitting
 # angle_data = angle_data.transpose(0, 1).flatten().squeeze()
 # angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
 center_angle = center_angle.flatten().squeeze()
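Note: the new 1/sqrt(len(data)) factor models an ideal 1:N power splitter: each of the N taps carries 1/N of the power, hence 1/sqrt(N) of the field amplitude. Sketch (N is an assumed tap count):

```python
import torch

N = 40                                      # e.g. 2 * output_dim taps (assumed)
field = torch.ones(N, dtype=torch.complex64)

split = field / torch.sqrt(torch.ones(1) * N)
total_power = split.abs().square().sum()
# the N taps together carry exactly one input tap's power
assert torch.allclose(total_power, field[0].abs().square())
```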
@@ -526,8 +491,8 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
"y": target,
|
"y": target,
|
||||||
"y_flipped": pol_flipped_target,
|
"y_flipped": pol_flipped_target,
|
||||||
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
|
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
|
||||||
# "center_angle": center_angle,
|
"center_angle": center_angle,
|
||||||
# "angles": angles,
|
"angles": angles,
|
||||||
"mean_angle": angles.mean(),
|
"mean_angle": angles.mean(),
|
||||||
# "sop": sop,
|
# "sop": sop,
|
||||||
# "angle_data": angle_data,
|
# "angle_data": angle_data,
|
||||||
|
|||||||
@@ -1,16 +1,23 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
import h5py
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from matplotlib.colors import LinearSegmentedColormap
|
from matplotlib.colors import LinearSegmentedColormap
|
||||||
|
# from cmap import Colormap as cm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.cluster.vq import kmeans2
|
from scipy.cluster.vq import kmeans2
|
||||||
import warnings
|
import warnings
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from rich import pretty
|
|
||||||
from rich import print
|
|
||||||
|
|
||||||
install()
|
install()
|
||||||
pretty.install()
|
# from rich import pretty
|
||||||
|
# from rich import print
|
||||||
|
|
||||||
|
# pretty.install()
|
||||||
|
|
||||||
|
|
||||||
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
||||||
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
     xaxis = np.arange(0, len(signal)) / sps
     return np.vstack([xaxis, signal])


 def create_symbol_sequence(n_symbols, skew=1):
     np.random.seed(42)
     data = np.random.randint(0, 4, n_symbols) / 4
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
     signal = np.convolve(data_padded, wavelet)
     signal = np.cumsum(signal)
     signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
+    mi, ma = np.min(signal), np.max(signal)
+
+    signal = (signal - mi) / (ma - mi)
+
+    mod = 0.8
+
+    signal *= mod
+    signal += 1 - mod

     return signal
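The added block in `generate_signal` first min-max normalises the trace, then remaps it with `mod = 0.8` into the range `[1 - mod, 1] = [0.2, 1.0]`, which looks like a transmitter with a finite extinction ratio (our reading; the commit does not say so). A quick check:

```python
import numpy as np

signal = np.array([0.0, 0.5, 1.0])  # already min-max normalised
mod = 0.8
print(signal * mod + (1 - mod))     # -> [0.2 0.6 1. ]
```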
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
     signal += awgn

     # min-max normalization
-    signal = signal - np.min(signal)
-    signal = signal / np.max(signal)
+    # signal = signal - np.min(signal)
+    # signal = signal / np.max(signal)
     return signal


@@ -68,26 +84,132 @@ def generate_wavelet(sps, oversample=3):


 class eye_diagram:
-    def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
+    def __init__(
+        self,
+        data,
+        *,
+        channel_names=None,
+        horizontal_bins=256,
+        vertical_bins=1000,
+        n_levels=4,
+        multithreaded=True,
+        save_file_or_dir=None,
+    ):
         # data has shape [channels, 2, samples]
         # each sample has a timestamp and a value
         if data.ndim == 2:
             data = data[np.newaxis, :, :]
-        self.channel_names = channel_names
         self.raw_data = data
-        self.channels = data.shape[0]
+        self.y_bins = np.zeros(1)
+        self.x_bins = np.zeros(1)
+        self.eye_data = np.zeros(1)
+        self.channel_names = channel_names
+        self.n_channels = data.shape[0]
         self.n_levels = n_levels
-        self.eye_stats = [{"success": False} for _ in range(self.channels)]
+        self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
         self.horizontal_bins = horizontal_bins
         self.vertical_bins = vertical_bins
         self.multi_threaded = multithreaded
+        self.analysed = False
         self.eye_built = False

-    def generate_eye_data(self):
+        self.save_file = save_file_or_dir
+
+    def load_data(self, file=None):
+        file = self.save_file if file is None else file
+
+        if file is None:
+            raise FileNotFoundError("No file specified.")
+
+        self.save_file = str(file)
+        # self.file_or_dir = self.save_file
+        with h5py.File(file, "r") as infile:
+            self.y_bins = infile["y_bins"][:]
+            self.x_bins = infile["x_bins"][:]
+            self.eye_data = infile["eye_data"][:]
+            self.channel_names = infile.attrs["channel_names"]
+            self.n_channels = infile.attrs["n_channels"]
+            self.n_levels = infile.attrs["n_levels"]
+            self.eye_stats = infile.attrs["eye_stats"]
+            self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
+            self.horizontal_bins = infile.attrs["horizontal_bins"]
+            self.vertical_bins = infile.attrs["vertical_bins"]
+            self.multi_threaded = infile.attrs["multithreaded"]
+            self.analysed = infile.attrs["analysed"]
+            self.eye_built = infile.attrs["eye_built"]

+    def save_data(self, file_or_dir=None):
+        file_or_dir = self.save_file if file_or_dir is None else file_or_dir
+        if file_or_dir is None:
+            file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
+        elif Path(file_or_dir).is_dir():
+            file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
+        else:
+            file = Path(file_or_dir)

+        # file.parent.mkdir(parents=True, exist_ok=True)

+        self.save_file = str(file)

+        with h5py.File(file, "w") as outfile:
+            outfile.create_dataset("eye_data", data=self.eye_data)
+            outfile.create_dataset("y_bins", data=self.y_bins)
+            outfile.create_dataset("x_bins", data=self.x_bins)
+            outfile.attrs["channel_names"] = self.channel_names
+            outfile.attrs["n_channels"] = self.n_channels
+            outfile.attrs["n_levels"] = self.n_levels
+            self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
+            outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
+            outfile.attrs["horizontal_bins"] = self.horizontal_bins
+            outfile.attrs["vertical_bins"] = self.vertical_bins
+            outfile.attrs["multithreaded"] = self.multi_threaded
+            outfile.attrs["analysed"] = self.analysed
+            outfile.attrs["eye_built"] = self.eye_built

+    @staticmethod
+    def convert_arrays(input_object):
+        """
+        convert ndarrays in (nested) dict to lists
+        """

+        if isinstance(input_object, np.ndarray):
+            return input_object.tolist()
+        elif isinstance(input_object, list):
+            return [eye_diagram.convert_arrays(old) for old in input_object]
+        elif isinstance(input_object, tuple):
+            return tuple(eye_diagram.convert_arrays(old) for old in input_object)
+        elif isinstance(input_object, dict):
+            dict_out = {}
+            for key, value in input_object.items():
+                dict_out[key] = eye_diagram.convert_arrays(value)
+            return dict_out
+        return input_object

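`convert_arrays` exists because the eye statistics are written to HDF5 attributes as JSON strings, and `json.dumps` rejects ndarrays. A sketch of the round-trip, assuming the class is importable from this script (module name illustrative):

```python
import json
import numpy as np

from eye import eye_diagram  # import path is illustrative

stats = {"levels": np.array([0.2, 0.45, 0.7, 0.95]), "success": True}
# json.dumps(stats) raises TypeError: Object of type ndarray is not JSON serializable
clean = eye_diagram.convert_arrays(stats)
print(json.dumps(clean))  # {"levels": [0.2, 0.45, 0.7, 0.95], "success": true}
```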
+    def generate_eye_data(
+        self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
+    ):
+        # modes:
+        # default: try to load eye data from file, if not found, generate and save
+        # load: try to load eye data from file, if not found, generate but don't save
+        # save: generate eye data and save
+        update_save = True
+        if mode == "load":
+            self.load_data(file_or_dir)
+            update_save = False
+        elif mode == "default":
+            try:
+                self.load_data(file_or_dir)
+                update_save = False
+            except (FileNotFoundError, IsADirectoryError):
+                pass

+        if not self.eye_built:
+            update_save = True
         self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
-        self.y_bins = np.zeros((self.channels, self.vertical_bins))
-        self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
-        datas = [self.raw_data[i] for i in range(self.channels)]
+        self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
+        self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
+        datas = [self.raw_data[i] for i in range(self.n_channels)]
         if self.multi_threaded:
             with multiprocessing.Pool() as pool:
                 results = pool.map(self.generate_eye_data_single, datas)
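Taken together, a usage sketch of the new cache-aware workflow (paths and bin counts are illustrative, not from the commit):

```python
eye = eye_diagram(
    data,                                   # shape [channels, 2, samples]
    horizontal_bins=256,
    vertical_bins=256,
    save_file_or_dir="tolerance_results/",  # directory or explicit *.eye.h5 path
)

eye.generate_eye_data(mode="default")  # load the cache if present, else build and save
eye.generate_eye_data(mode="load")     # load only; as written, errors propagate if no file is given
eye.generate_eye_data(mode="save")     # rebuild and overwrite the cache
eye.generate_eye_data(mode="nosave")   # build in memory, leave the cache untouched
```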
@@ -98,54 +220,112 @@ class eye_diagram:
                 self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
         self.eye_built = True

+        if mode == "save" or (mode == "default" and update_save):
+            self.save_data(file_or_dir)
+
     def generate_eye_data_single(self, data):
         eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
         data_min = np.min(data[1, :])
         data_max = np.max(data[1, :])
+        # round down/up to 1 decimal
+        data_min = np.floor(data_min * 10) / 10
+        data_max = np.ceil(data_max * 10) / 10
+        # data_range = data_max - data_min
+        # data_min -= 0.1 * data_range
+        # data_max += 0.1 * data_range
+        # data_min = -0.05
+        # data_max += 0.05
+        # data[1,:] -= np.min(data[1, :])
+        # data[1,:] /= np.max(data[1, :])
+        # data_min = 0
+        # data_max = 1
         y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
-        t_vals = data[0, :] % 2
-        val_vals = data[1, :]
+        t_vals = data[0, :] % 2  # + np.random.randn(*data[0, :].shape) * 1 / (512)
+        val_vals = data[1, :]  # + np.random.randn(*data[1, :].shape) * 1 / (320)
         x_indices = np.digitize(t_vals, self.x_bins) - 1
         y_indices = np.digitize(val_vals, y_bins) - 1
         np.add.at(eye_data, (y_indices, x_indices), 1)
         return eye_data, y_bins

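The binning in `generate_eye_data_single` is a 2-D histogram: `np.digitize` maps each (time, amplitude) sample to a bin, and `np.add.at` accumulates repeated hits that a plain fancy-index assignment (`hist[yi, xi] += 1`) would silently collapse to one. A small worked example with made-up samples:

```python
import numpy as np

x_bins = np.linspace(0, 2, 4, endpoint=False)  # [0.   0.5  1.   1.5 ]
y_bins = np.linspace(0, 1, 4, endpoint=False)  # [0.   0.25 0.5  0.75]
t_vals = np.array([0.1, 0.6, 1.7, 1.9]) % 2
vals = np.array([0.0, 0.5, 1.0, 1.0])

hist = np.zeros((4, 4))
xi = np.digitize(t_vals, x_bins) - 1           # [0 1 3 3]
yi = np.digitize(vals, y_bins) - 1             # [0 2 3 3]
np.add.at(hist, (yi, xi), 1)
print(hist[3, 3])                              # -> 2.0, both hits on bin (3, 3) counted
```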
-    def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
+    def plot(
+        self,
+        title="Eye Diagram",
+        stats=True,
+        all_stats=True,
+        show=True,
+        mode: Literal["default", "load", "save", "nosave"] = "default",
+        # save_images = False,
+        # image_dir = None,
+        # cmap=None,
+    ):
+        if stats and not self.analysed:
+            self.analyse(mode=mode)
         if not self.eye_built:
-            self.generate_eye_data()
+            self.generate_eye_data(mode=mode)
         cmap = LinearSegmentedColormap.from_list(
             "eyemap",
-            [(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
+            [
+                (0, "#FFFFFF00"),
+                (0.1, "blue"),
+                (0.2, "cyan"),
+                (0.5, "green"),
+                (0.8, "yellow"),
+                (0.9, "red"),
+                (1, "magenta"),
+            ],
         )
-        if self.channels % 2 == 0:
+        # cmap = cm('google:turbo_r' if cmap is None else cmap)
+        # first = cmap(-1)
+        # cmap = cmap.to_mpl()
+        # cmap.set_under(first, alpha=0)
+        if self.n_channels % 2 == 0:
             rows = 2
-            cols = self.channels // 2
+            cols = self.n_channels // 2
         else:
-            cols = int(np.ceil(np.sqrt(self.channels)))
-            rows = int(np.ceil(self.channels / cols))
+            cols = int(np.ceil(np.sqrt(self.n_channels)))
+            rows = int(np.ceil(self.n_channels / cols))
         fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
         fig.suptitle(title)
         fig.tight_layout()
         ax = np.atleast_1d(ax).transpose().flatten()
-        for i in range(self.channels):
-            ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
-            if (i+1) % rows == 0:
+        for i in range(self.n_channels):
+            ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
+            if (i + 1) % rows == 0:
                 ax[i].set_xlabel("Symbol")
             if i < rows:
                 ax[i].set_ylabel("Amplitude")
             ax[i].grid()
+            ax[i].set_axisbelow(True)
             ax[i].imshow(
-                self.eye_data[i],
+                self.eye_data[i] - 0.1,
                 origin="lower",
                 aspect="auto",
                 cmap=cmap,
                 extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
+                interpolation="gaussian",
+                vmin=0,
+                zorder=3,
             )
             ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
             ymin = np.min(self.y_bins[:, 0])
             ymax = np.max(self.y_bins[:, -1])
             yspan = ymax - ymin
             ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
+            # if save_images:
+            #     image_dir = "images_out" if image_dir is None else image_dir
+            #     image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
+            #     image_path.parent.mkdir(parents=True, exist_ok=True)
+            #     # plt.imsave(
+            #     #     image_path,
+            #     #     self.eye_data[i] - 0.1,
+            #     #     origin="lower",
+            #     #     # aspect="auto",
+            #     #     cmap=cmap,
+            #     #     # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
+            #     #     # interpolation="gaussian",
+            #     #     vmin=0,
+            #     #     # zorder=3,
+            #     #     )
             if stats and self.eye_stats[i]["success"]:
                 # # add min_area above the plot
                 # ax[i].annotate(
@@ -159,7 +339,7 @@ class eye_diagram:

                 if all_stats:
                     ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
-                    y_ticks = (*self.eye_stats[i]["levels"],*self.eye_stats[i]["thresholds"])
+                    y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
                     # y_ticks = np.sort(y_ticks)
                     ax[i].set_yticks(y_ticks)
                     # add arrows for amplitudes
@@ -235,19 +415,19 @@ class eye_diagram:
     def calculate_thresholds(levels):
         ret = np.cumsum(levels, dtype=float)
         ret[2:] = ret[2:] - ret[:-2]
-        return ret[1:]/2
+        return ret[1:] / 2

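`calculate_thresholds` is a cumulative-sum moving average of window 2: after `ret[2:] -= ret[:-2]`, `ret[i]` holds `levels[i-1] + levels[i]`, so `ret[1:] / 2` is the midpoints between adjacent levels, i.e. the PAM decision thresholds. Worked through for four illustrative level estimates:

```python
import numpy as np

levels = np.array([0.2, 0.45, 0.7, 0.95])  # illustrative PAM4 level estimates
ret = np.cumsum(levels, dtype=float)        # [0.2  0.65 1.35 2.3 ]
ret[2:] = ret[2:] - ret[:-2]                # [0.2  0.65 1.15 1.65]
print(ret[1:] / 2)                          # -> [0.325 0.575 0.825]
```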
     def analyse_single(self, data, index):
         warnings.filterwarnings("error")
         eye_stats = {}
-        eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index]
+        eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
         try:
             approx_levels = eye_diagram.approximate_levels(data, self.n_levels)

             time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)

-            eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2
-            eye_stats["time_midpoint"] = 1.0
+            eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
+            # eye_stats["time_midpoint"] = 1.0

             eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
                 data, approx_levels, time_bounds
@@ -257,9 +437,7 @@ class eye_diagram:

             eye_stats["amplitudes"] = np.diff(eye_stats["levels"])

-            eye_stats["heights"] = eye_diagram.calculate_eye_heights(
-                eye_stats["amplitude_clusters"]
-            )
+            eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])

             eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
                 data, eye_stats["levels"]
@@ -291,17 +469,39 @@ class eye_diagram:
         warnings.resetwarnings()
         return eye_stats

-    def analyse(self):
+    def analyse(
+        self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
+    ):
+        # modes:
+        # default: try to load eye data from file, if not found, generate and save
+        # load: try to load eye data from file, if not found, generate but don't save
+        # save: generate eye data and save
+        update_save = True
+        if mode == "load":
+            self.load_data(file_or_dir)
+            update_save = False
+        elif mode == "default":
+            try:
+                self.load_data(file_or_dir)
+                update_save = False
+            except (FileNotFoundError, IsADirectoryError):
+                pass

+        if not self.analysed:
+            update_save = True
         self.eye_stats = []
         if self.multi_threaded:
             with multiprocessing.Pool() as pool:
-                results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
+                results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
             for i, result in enumerate(results):
                 self.eye_stats.append(result)
         else:
-            for i in range(self.channels):
+            for i in range(self.n_channels):
                 self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
+        self.analysed = True

+        if mode == "save" or (mode == "default" and update_save):
+            self.save_data(file_or_dir)

     @staticmethod
     def approximate_levels(data, levels):
@@ -443,7 +643,7 @@ class eye_diagram:


 if __name__ == "__main__":
-    length = int(2**14)
+    length = int(2**16)
     # data = generate_sample_data(length, noise=1)
     # data1 = generate_sample_data(length, noise=0.01)
     # data2 = generate_sample_data(length, noise=0.01, skew=1.2)
@@ -451,13 +651,13 @@ if __name__ == "__main__":

    # data = np.stack([data, data1, data2, data3])

-    data = generate_sample_data(length, noise=0.005)
-    eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
-    eye.analyse()
-    attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
-    for i, channel in enumerate(eye.eye_stats):
-        print(f"Channel {i}")
-        print_data = {attr: channel[attr] for attr in attrs}
-        print(print_data)
+    data = generate_sample_data(length, noise=0.0000)
+    eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
+    eye.plot(mode="nosave", stats=False)
+    # attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
+    # for i, channel in enumerate(eye.eye_stats):
+    #     print(f"Channel {i}")
+    #     print_data = {attr: channel[attr] for attr in attrs}
+    #     print(print_data)

-    eye.plot()
+    # eye.plot()
122 src/single-core-regen/util/mpl.py Normal file
@@ -0,0 +1,122 @@
+# Copyright (c) 2015, Warren Weckesser. All rights reserved.
+# This software is licensed according to the "BSD 2-clause" license.

+# modified by Joseph Hopfmüller in 2025,
+# for integration into optical regeneration analysis scripts

+from pathlib import Path
+from matplotlib.colors import LinearSegmentedColormap
+import matplotlib.colors as colors
+import numpy as _np
+from .core import grid_count as _grid_count
+import matplotlib.pyplot as _plt
+import numpy as np
+from scipy.ndimage import gaussian_filter


+# from ._common import _common_doc


+__all__ = ["eyediagram"]  # , 'eyediagram_lines']


+# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
+#     """
+#     Plot an eye diagram using matplotlib by repeatedly calling the `plot`
+#     function.
+#     <common>
+
+#     """
+#     start = offset
+#     while start < len(y):
+#         end = start + window_size
+#         if end > len(y):
+#             end = len(y)
+#         yy = y[start:end+1]
+#         _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
+#         start = end

+# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
+#                                                             _common_doc)


+eyemap = LinearSegmentedColormap.from_list(
+    "eyemap",
+    [
+        (0, "#0000FF00"),
+        (0.1, "blue"),
+        (0.2, "cyan"),
+        (0.5, "green"),
+        (0.8, "yellow"),
+        (0.9, "red"),
+        (1, "magenta"),
+    ],
+)


+def eyediagram(
+    y,
+    window_size,
+    offset=0,
+    colorbar=False,
+    show=False,
+    save_im=False,
+    overwrite=False,
+    blur: int | bool = True,
+    save_path="out.png",
+    bounds=None,
+    **imshowkwargs,
+):
+    """
+    Plot an eye diagram using matplotlib by creating an image and calling
+    the `imshow` function.
+    <common>
+    """
+    if bounds is None:
+        ymax = y.max()
+        ymin = y.min()
+        yamp = ymax - ymin
+        ymin = ymin - 0.05 * yamp
+        ymax = ymax + 0.05 * yamp
+        ymin = np.floor(ymin * 10) / 10
+        ymax = np.ceil(ymax * 10) / 10
+        bounds = (ymin, ymax)
+    counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
+    counts = counts.astype(_np.float32)
+    origin = imshowkwargs.pop("origin", "lower")
+    cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
+    vmin = imshowkwargs.pop("vmin", 1)
+    vmax = imshowkwargs.pop("vmax", None)
+    cmap.set_under("white", alpha=0)

+    if show:
+        _plt.imshow(
+            counts.T[::-1, :],
+            extent=[0, 2, *bounds],
+            origin=origin,
+            cmap=cmap,
+            vmin=vmin,
+            vmax=vmax,
+            **imshowkwargs,
+        )
+        _plt.grid()
+        if colorbar:
+            _plt.colorbar()

+    if Path(save_path).is_file() and not overwrite:
+        save_im = False
+    if save_im:
+        from PIL import Image
+        arr = counts.T[::-1, :]
+        if origin == "lower":
+            arr = arr[::-1]
+        arr = (arr - arr.min()) / (arr.max() - arr.min())
+        image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
+        image.save(save_path)
+        # print("-")

+    if show:
+        _plt.show()


+# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)
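A hedged call sketch for the new `eyediagram` helper; it assumes the sibling `util.core` module providing `grid_count` is importable, and the waveform and parameters are illustrative:

```python
import numpy as np
from util.mpl import eyediagram  # import path is illustrative

sps = 128                                     # samples per symbol
t = np.arange(100 * sps) / sps
y = 0.5 + 0.4 * np.sign(np.sin(np.pi * t))    # crude two-level stand-in waveform

# two symbols per trace; writes out.png only if it does not exist yet
eyediagram(y, window_size=2 * sps, offset=0, show=True, save_im=True)
```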