wip
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -163,4 +163,5 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
tolerance_results/datasets/*
|
||||
tolerance_results/*
|
||||
data/*
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
|
||||
size 649
|
||||
3
data/npys/6789fdea2609799ef2e975907625b79a.h5
Normal file
3
data/npys/6789fdea2609799ef2e975907625b79a.h5
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
|
||||
size 134481920
|
||||
59
notes/tolerance_testing.md
Normal file
59
notes/tolerance_testing.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Baseline Models
|
||||
|
||||
## a) D+S, pol_error 0, ortho_error 0, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250118_225918.tar
|
||||
```
|
||||
|
||||
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250116_214816.tar
|
||||
```
|
||||
|
||||
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250117_122319.tar
|
||||
```
|
||||
|
||||
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
|
||||
|
||||
birefringence angle pi/2 (worst case)
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250117_144001.tar
|
||||
```
|
||||
2
pypho
2
pypho
Submodule pypho updated: dd015f4852...e44fc477fe
@@ -164,10 +164,14 @@ class regenerator(Module):
|
||||
module = act_function(size=dims[-1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
|
||||
|
||||
module = Scale(size=dims[-1])
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
|
||||
|
||||
if self.rotation:
|
||||
module = rotate()
|
||||
self.add_module("rotate", module)
|
||||
|
||||
|
||||
# module = Scale(size=dims[-1])
|
||||
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
|
||||
|
||||
|
||||
@@ -18,9 +18,11 @@ class DataSettings:
|
||||
shuffle: bool = True
|
||||
in_out_delay: float = 0
|
||||
xy_delay: tuple | float | int = 0
|
||||
drop_first: int = 1000
|
||||
drop_first: int = 64
|
||||
drop_last: int = 64
|
||||
train_split: float = 0.8
|
||||
polarisations: tuple | list = (0,)
|
||||
# cross_pol_interference: float = 0
|
||||
randomise_polarisations: bool = False
|
||||
osnr: float | int = None
|
||||
seed: int = None
|
||||
@@ -93,6 +95,12 @@ class ModelSettings:
|
||||
"""
|
||||
|
||||
|
||||
def _early_stop_default_kwargs():
|
||||
return {
|
||||
"threshold": 1e-05,
|
||||
"plateau": 25,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class OptimizerSettings:
|
||||
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
|
||||
@@ -101,6 +109,9 @@ class OptimizerSettings:
|
||||
scheduler: str | None = None
|
||||
scheduler_kwargs: dict | None = None
|
||||
|
||||
early_stopping: bool = False
|
||||
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
|
||||
|
||||
"""
|
||||
change to:
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
import random
|
||||
import matplotlib
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
import torch.nn.utils.parametrize
|
||||
|
||||
try:
|
||||
@@ -46,13 +47,72 @@ from .settings import (
|
||||
PytorchSettings,
|
||||
)
|
||||
|
||||
from cmcrameri import cm
|
||||
# from matplotlib import colors as mcolors
|
||||
# alpha_map = mcolors.LinearSegmentedColormap(
|
||||
# 'alphamap',
|
||||
# {
|
||||
# 'red': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'green': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'blue': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'alpha': [
|
||||
# (0, 1, 1),
|
||||
# # (0.2, 0.2, 0.1),
|
||||
# (1, 0, 0)
|
||||
# ]
|
||||
# }
|
||||
# )
|
||||
# alpha_map.set_bad(color="#AAAAAA")
|
||||
|
||||
def pad_to_size(array, size):
|
||||
if not hasattr(size, "__len__"):
|
||||
size = (size, size)
|
||||
|
||||
left = (
|
||||
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
|
||||
)
|
||||
right = (
|
||||
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
|
||||
)
|
||||
top = (
|
||||
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
|
||||
)
|
||||
bottom = (
|
||||
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
|
||||
)
|
||||
|
||||
array: np.ndarray = array
|
||||
if array.ndim == 2:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
elif array.ndim == 3:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
(0,0)
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
|
||||
def traverse_dict_update(target, source):
|
||||
for k, v in source.items():
|
||||
if isinstance(v, dict):
|
||||
if k not in target:
|
||||
target[k] = {}
|
||||
traverse_dict_update(target[k], v)
|
||||
try:
|
||||
if k not in target:
|
||||
target[k] = {}
|
||||
traverse_dict_update(target[k], v)
|
||||
except TypeError:
|
||||
if k not in target.__dict__:
|
||||
setattr(target, k, {})
|
||||
traverse_dict_update(target.__dict__[k], v)
|
||||
else:
|
||||
try:
|
||||
target[k] = v
|
||||
@@ -261,6 +321,7 @@ class PolarizationTrainer:
|
||||
target_delay=in_out_delay,
|
||||
xy_delay=xy_delay,
|
||||
drop_first=self.data_settings.drop_first,
|
||||
drop_last=self.data_settings.drop_last,
|
||||
dtype=dtype,
|
||||
real=not dtype.is_complex,
|
||||
num_symbols=num_symbols,
|
||||
@@ -602,6 +663,7 @@ class RegenerationTrainer:
|
||||
console=None,
|
||||
checkpoint_path=None,
|
||||
settings_override=None,
|
||||
new_model=False,
|
||||
reset_epoch=False,
|
||||
):
|
||||
self.resume = checkpoint_path is not None
|
||||
@@ -615,12 +677,23 @@ class RegenerationTrainer:
|
||||
models.regenerator,
|
||||
torch.nn.utils.parametrizations.orthogonal,
|
||||
])
|
||||
# self.new_model = True
|
||||
self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
if self.resume:
|
||||
print(f"loading checkpoint from {checkpoint_path}")
|
||||
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
|
||||
if settings_override is not None:
|
||||
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
|
||||
if reset_epoch:
|
||||
|
||||
if not new_model:
|
||||
# self.new_model = False
|
||||
checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
|
||||
if checkpoint_file.startswith("best"):
|
||||
self.model_name = "_".join(checkpoint_file.split("_")[1:])
|
||||
else:
|
||||
self.model_name = "_".join(checkpoint_file.split("_")[:-1])
|
||||
|
||||
if new_model or reset_epoch:
|
||||
self.checkpoint_dict["epoch"] = -1
|
||||
|
||||
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
|
||||
@@ -654,7 +727,7 @@ class RegenerationTrainer:
|
||||
self.writer = None
|
||||
|
||||
def setup_tb_writer(self, append=None):
|
||||
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
|
||||
log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
|
||||
if append is not None:
|
||||
log_dir += "_" + str(append)
|
||||
|
||||
@@ -697,8 +770,8 @@ class RegenerationTrainer:
|
||||
|
||||
output_dim = self.model_settings.output_dim
|
||||
|
||||
# if self.data_settings.polarisations:
|
||||
output_dim *= 2
|
||||
if self.data_settings.polarisations:
|
||||
output_dim *= 2
|
||||
|
||||
dtype = getattr(torch, self.data_settings.dtype)
|
||||
|
||||
@@ -755,11 +828,13 @@ class RegenerationTrainer:
|
||||
randomise_polarisations = self.data_settings.randomise_polarisations
|
||||
polarisations = self.data_settings.polarisations
|
||||
osnr = self.data_settings.osnr
|
||||
# cross_pol_interference = self.data_settings.cross_pol_interference
|
||||
if override is not None:
|
||||
num_symbols = override.get("num_symbols", None)
|
||||
config_path = override.get("config_path", config_path)
|
||||
polarisations = override.get("polarisations", polarisations)
|
||||
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
|
||||
# cross_pol_interference = override.get("angle_var", 0)
|
||||
# get dataset
|
||||
dataset = FiberRegenerationDataset(
|
||||
file_path=config_path,
|
||||
@@ -768,11 +843,13 @@ class RegenerationTrainer:
|
||||
target_delay=in_out_delay,
|
||||
xy_delay=xy_delay,
|
||||
drop_first=self.data_settings.drop_first,
|
||||
drop_last=self.data_settings.drop_last,
|
||||
dtype=dtype,
|
||||
real=not dtype.is_complex,
|
||||
num_symbols=num_symbols,
|
||||
randomise_polarisations=randomise_polarisations,
|
||||
polarisations=polarisations,
|
||||
# cross_pol_interference=cross_pol_interference,
|
||||
osnr = osnr,
|
||||
)
|
||||
|
||||
@@ -842,8 +919,10 @@ class RegenerationTrainer:
|
||||
running_loss = 0.0
|
||||
self.model.train()
|
||||
loader_len = len(train_loader)
|
||||
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
x_key = "x"
|
||||
y_key = "y"
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch[x_key]
|
||||
y = batch[y_key]
|
||||
@@ -855,7 +934,10 @@ class RegenerationTrainer:
|
||||
angle.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x, -angle)
|
||||
# loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
|
||||
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -898,8 +980,10 @@ class RegenerationTrainer:
|
||||
|
||||
self.model.eval()
|
||||
running_error = 0
|
||||
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
x_key = "x"
|
||||
y_key = "y"
|
||||
with torch.no_grad():
|
||||
for _, batch in enumerate(valid_loader):
|
||||
x = batch[x_key]
|
||||
@@ -911,7 +995,9 @@ class RegenerationTrainer:
|
||||
angle.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x, -angle)
|
||||
# error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
|
||||
error_value = error.item()
|
||||
running_error += error_value
|
||||
|
||||
@@ -928,7 +1014,7 @@ class RegenerationTrainer:
|
||||
if (epoch + 1) % 10 == 0 or epoch < 10:
|
||||
# plotting is slow, so only do it every 10 epochs
|
||||
title_append, subtitle = self.build_title(epoch + 1)
|
||||
head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
@@ -944,6 +1030,11 @@ class RegenerationTrainer:
|
||||
eye_fig,
|
||||
epoch + 1,
|
||||
)
|
||||
self.writer.add_figure(
|
||||
"weights",
|
||||
weight_fig,
|
||||
epoch + 1,
|
||||
)
|
||||
|
||||
self.writer.add_figure(
|
||||
"powers",
|
||||
@@ -967,9 +1058,10 @@ class RegenerationTrainer:
|
||||
regen = []
|
||||
timestamps = []
|
||||
angles = []
|
||||
x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
|
||||
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
x_key = "x"
|
||||
y_key = "y"
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for batch in loader:
|
||||
@@ -1056,7 +1148,7 @@ class RegenerationTrainer:
|
||||
)
|
||||
|
||||
title_append, subtitle = self.build_title(0)
|
||||
head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
@@ -1072,6 +1164,11 @@ class RegenerationTrainer:
|
||||
eye_fig,
|
||||
0,
|
||||
)
|
||||
self.writer.add_figure(
|
||||
"weights",
|
||||
weight_fig,
|
||||
0,
|
||||
)
|
||||
|
||||
self.writer.add_figure(
|
||||
"powers",
|
||||
@@ -1103,6 +1200,9 @@ class RegenerationTrainer:
|
||||
|
||||
train_loader, valid_loader = self.get_sliced_data()
|
||||
|
||||
# train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
|
||||
# train_loader.dataset.fiber_in.to(self.pytorch_settings.device)
|
||||
|
||||
optimizer_name = self.optimizer_settings.optimizer
|
||||
|
||||
# lr = self.optimizer_settings.learning_rate
|
||||
@@ -1132,6 +1232,7 @@ class RegenerationTrainer:
|
||||
# except ValueError:
|
||||
# pass
|
||||
|
||||
self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
|
||||
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
|
||||
enable_progress = True
|
||||
if enable_progress:
|
||||
@@ -1147,29 +1248,64 @@ class RegenerationTrainer:
|
||||
epoch,
|
||||
enable_progress=enable_progress,
|
||||
)
|
||||
if self.early_stop(loss):
|
||||
self.save_model_checkpoints(epoch, loss)
|
||||
break
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
self.scheduler.step(loss)
|
||||
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
|
||||
if self.pytorch_settings.save_models and self.model is not None:
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint = self.build_checkpoint_dict(loss, epoch)
|
||||
self.save_checkpoint(checkpoint, save_path)
|
||||
|
||||
if loss < self.best["loss"]:
|
||||
self.best = checkpoint
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.save_checkpoint(self.best, save_path)
|
||||
self.save_model_checkpoints(epoch, loss)
|
||||
self.writer.flush()
|
||||
|
||||
save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
|
||||
print(f"Training complete. Best checkpoint: {save_path}")
|
||||
self.writer.close()
|
||||
return self.best
|
||||
|
||||
def early_stop(self, loss):
|
||||
# not stopping early at all
|
||||
if not self.optimizer_settings.early_stopping:
|
||||
return False
|
||||
|
||||
# stopping because of abs threshold
|
||||
if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None:
|
||||
if loss <= loss_thr:
|
||||
print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
|
||||
return True
|
||||
|
||||
# update vals
|
||||
if loss < self.early_stop_vals["min_loss"]:
|
||||
self.early_stop_vals["min_loss"] = loss
|
||||
self.early_stop_vals["plateau_cnt"] = 0
|
||||
return False
|
||||
|
||||
# stopping because of plateau
|
||||
if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None:
|
||||
self.early_stop_vals["plateau_cnt"] += 1
|
||||
if self.early_stop_vals["plateau_cnt"] >= plateau_thresh:
|
||||
print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals["plateau_cnt"]} >= {plateau_thresh})")
|
||||
return True
|
||||
|
||||
# no stop
|
||||
return False
|
||||
|
||||
def save_model_checkpoints(self, epoch, loss):
|
||||
if self.pytorch_settings.save_models and self.model is not None:
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint = self.build_checkpoint_dict(loss, epoch)
|
||||
self.save_checkpoint(checkpoint, save_path)
|
||||
|
||||
if loss < self.best["loss"]:
|
||||
self.best = checkpoint
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.save_checkpoint(self.best, save_path)
|
||||
|
||||
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
|
||||
powers = [power / powers[0] for power in powers]
|
||||
fig, ax = plt.subplots()
|
||||
@@ -1190,6 +1326,77 @@ class RegenerationTrainer:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
|
||||
model_params = []
|
||||
plots = []
|
||||
dims = []
|
||||
for num, (layer_name, layer) in enumerate(model.named_children()):
|
||||
onn_weights = layer.ONN.weight
|
||||
onn_weights = onn_weights.detach().cpu().numpy()
|
||||
onn_values = np.abs(onn_weights).real
|
||||
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
|
||||
|
||||
model_params.append({layer_name: onn_weights})
|
||||
plots.append({layer_name: (num, onn_values, onn_angles)})
|
||||
dims.append(onn_weights.shape[0])
|
||||
|
||||
max_size = np.max(dims)
|
||||
|
||||
for plot in plots:
|
||||
layer_name, (num, onn_values, onn_angles) = plot.popitem()
|
||||
|
||||
if num == 0:
|
||||
value_img = onn_values
|
||||
angle_img = onn_angles
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, None))
|
||||
onn_values = pad_to_size(onn_values, (max_size, None))
|
||||
else:
|
||||
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
|
||||
value_img = np.concatenate((value_img, onn_values), axis=1)
|
||||
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
|
||||
|
||||
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
|
||||
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
|
||||
fig.tight_layout()
|
||||
|
||||
dividers = map(make_axes_locatable, axs)
|
||||
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
|
||||
|
||||
masked_value_img = value_img
|
||||
cmap = cm.batlow
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
|
||||
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
|
||||
|
||||
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
|
||||
cmap = cm.romaO
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
|
||||
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
|
||||
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
|
||||
|
||||
|
||||
axs[0].axis("off")
|
||||
axs[1].axis("off")
|
||||
|
||||
axs[0].set_title("Values")
|
||||
axs[1].set_title("Angles")
|
||||
|
||||
title = "Layer Weights"
|
||||
if title_append:
|
||||
title += f" {title_append}"
|
||||
if subtitle:
|
||||
title += f"\n{subtitle}"
|
||||
fig.suptitle(title)
|
||||
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def _plot_model_response_eye(
|
||||
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
|
||||
):
|
||||
@@ -1354,7 +1561,7 @@ class RegenerationTrainer:
|
||||
|
||||
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
|
||||
self.data_settings.drop_first = int(64 + random.randint(0, 1000))
|
||||
self.data_settings.shuffle = False
|
||||
self.data_settings.train_split = 1.0
|
||||
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
|
||||
@@ -1363,7 +1570,7 @@ class RegenerationTrainer:
|
||||
if isinstance(self.data_settings.config_path, (list, tuple))
|
||||
else self.data_settings.config_path
|
||||
)
|
||||
fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
|
||||
# fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
|
||||
if not hasattr(self, "_plot_loader"):
|
||||
self._plot_loader, _ = self.get_sliced_data(
|
||||
override={
|
||||
@@ -1376,6 +1583,7 @@ class RegenerationTrainer:
|
||||
}
|
||||
)
|
||||
self._sps = self._plot_loader.dataset.samples_per_symbol
|
||||
fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
|
||||
self.data_settings = data_settings_backup
|
||||
self.pytorch_settings = pytorch_settings_backup
|
||||
|
||||
@@ -1403,7 +1611,7 @@ class RegenerationTrainer:
|
||||
import gc
|
||||
|
||||
head_fig = self._plot_model_response_head(
|
||||
fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps],
|
||||
fiber_out[: self.pytorch_settings.head_symbols * self._sps],
|
||||
fiber_in[: self.pytorch_settings.head_symbols * self._sps],
|
||||
regen[: self.pytorch_settings.head_symbols * self._sps],
|
||||
angles[: self.pytorch_settings.head_symbols * self._sps],
|
||||
@@ -1417,7 +1625,7 @@ class RegenerationTrainer:
|
||||
# raise NotImplementedError("Eye diagram not implemented")
|
||||
eye_fig = self._plot_model_response_eye(
|
||||
fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
fiber_out[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
regen[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
@@ -1426,9 +1634,11 @@ class RegenerationTrainer:
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
|
||||
weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
|
||||
gc.collect()
|
||||
|
||||
return head_fig, eye_fig, power_fig
|
||||
return head_fig, eye_fig, weight_fig, power_fig
|
||||
|
||||
def build_title(self, number: int):
|
||||
title_append = f"epoch {number}"
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -25,7 +27,29 @@ from hypertraining import models
|
||||
# ),
|
||||
# constant_values=(-np.inf, -np.inf),
|
||||
# )
|
||||
|
||||
|
||||
def register_puccs_cmap(puccs_path=None):
|
||||
puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
|
||||
|
||||
colors = []
|
||||
# keys = None
|
||||
with open(puccs_path, "r") as f:
|
||||
for i, line in enumerate(f.readlines()):
|
||||
elements = tuple(line.split(","))
|
||||
# if i == 0:
|
||||
# # keys = elements
|
||||
# continue
|
||||
# else:
|
||||
try:
|
||||
colors.append(tuple(map(float, elements[4:])))
|
||||
except ValueError:
|
||||
continue
|
||||
# colors = []
|
||||
# for current in puccs_csv_data:
|
||||
# colors.append(tuple(current[4:]))
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import matplotlib as mpl
|
||||
mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
|
||||
|
||||
def pad_to_size(array, size):
|
||||
if not hasattr(size, "__len__"):
|
||||
@@ -65,7 +89,7 @@ def pad_to_size(array, size):
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
|
||||
def model_plot(model_path):
|
||||
def model_plot(model_path, show=True):
|
||||
torch.serialization.add_safe_globals([
|
||||
*util.complexNN.__all__,
|
||||
GlobalSettings,
|
||||
@@ -81,173 +105,113 @@ def model_plot(model_path):
|
||||
dims = checkpoint_dict["model_kwargs"].pop("dims")
|
||||
|
||||
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
|
||||
model.load_state_dict(checkpoint_dict["model_state_dict"])
|
||||
model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)
|
||||
|
||||
model_params = []
|
||||
plots = []
|
||||
max_size = np.max(dims)
|
||||
# max_act_size = np.max(dims[1:])
|
||||
|
||||
angles = [None, None]
|
||||
weights = [None, None]
|
||||
# angles = [None, None]
|
||||
# weights = [None, None]
|
||||
|
||||
for num, (layer_name, layer) in enumerate(model.named_children()):
|
||||
# each layer contains an "ONN" layer and an "activation" layer
|
||||
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
|
||||
onn_weights = layer.ONN.weight.T
|
||||
onn_weights = layer.ONN.weight
|
||||
onn_weights = onn_weights.detach().cpu().numpy()
|
||||
onn_values = np.abs(onn_weights).real
|
||||
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
|
||||
|
||||
|
||||
act = layer.activation
|
||||
|
||||
act_values = np.ones((act.size, 1))
|
||||
|
||||
act_values = np.nan * act_values
|
||||
|
||||
act_angles = act.phase.unsqueeze(-1).detach().cpu().numpy()
|
||||
...
|
||||
# act_phi_bias = torch.pi * act.V_bias / (act.V_pi + 1e-8)
|
||||
# act_phi_gain = torch.pi * (act.alpha * act.gain * act.responsivity) / (act.V_pi + 1e-8)
|
||||
# xs = (0.01, 0.1, 1)
|
||||
|
||||
# act_values = np.zeros((act.size, len(xs)*2))
|
||||
# act_angles = np.zeros((act.size, len(xs)*2))
|
||||
|
||||
# act_values[:,:] = np.nan
|
||||
# act_angles[:,:] = np.nan
|
||||
|
||||
# for xi, x in enumerate(xs):
|
||||
# phi_intermediate = act_phi_gain * x**2 + act_phi_bias
|
||||
|
||||
# act_resulting_gain = (
|
||||
# 1j
|
||||
# * torch.sqrt(1-act.alpha)
|
||||
# * torch.exp(-0.5j * phi_intermediate)
|
||||
# * torch.cos(0.5 * phi_intermediate)
|
||||
# * x
|
||||
# )
|
||||
|
||||
# act_resulting_gain = act_resulting_gain.detach().cpu().numpy()
|
||||
# act_values[:, xi*2] = np.abs(act_resulting_gain).real
|
||||
# act_angles[:, xi*2] = np.mod(np.angle(act_resulting_gain), 2*np.pi).real
|
||||
|
||||
|
||||
|
||||
# if angles[0] is None or angles[0] > np.min(onn_angles.flatten()):
|
||||
# angles[0] = np.min(onn_angles.flatten())
|
||||
# if angles[1] is None or angles[1] < np.max(onn_angles.flatten()):
|
||||
# angles[1] = np.max(onn_angles.flatten())
|
||||
# if weights[0] is None or weights[0] > np.min(onn_weights.flatten()):
|
||||
# weights[0] = np.min(onn_weights.flatten())
|
||||
# if weights[1] is None or weights[1] < np.max(onn_weights.flatten()):
|
||||
# weights[1] = np.max(onn_weights.flatten())
|
||||
|
||||
model_params.append({layer_name: onn_weights})
|
||||
plots.append({layer_name: (num, onn_values, onn_angles, act_values, act_angles)})
|
||||
plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})
|
||||
|
||||
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
|
||||
|
||||
for plot in plots:
|
||||
layer_name, (num, onn_values, onn_angles, act_values, act_angles) = plot.popitem()
|
||||
# for_plot[:, :, 0] = (for_plot[:, :, 0] - angles[0]) / (angles[1] - angles[0])
|
||||
# for_plot[:, :, 1] = (for_plot[:, :, 1] - weights[0]) / (weights[1] - weights[0])
|
||||
|
||||
onn_values = np.ma.array(onn_values, mask=np.isnan(onn_values))
|
||||
onn_values = onn_values - np.min(onn_values)
|
||||
onn_values = onn_values / np.max(onn_values)
|
||||
|
||||
act_values = np.ma.array(act_values, mask=np.isnan(act_values))
|
||||
act_values = act_values - np.min(act_values)
|
||||
act_values = act_values / np.max(act_values)
|
||||
|
||||
|
||||
onn_values = onn_values
|
||||
onn_values = pad_to_size(onn_values, (max_size, None))
|
||||
|
||||
act_values = act_values
|
||||
act_values = pad_to_size(act_values, (max_size, 3))
|
||||
|
||||
onn_angles = onn_angles / np.pi
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, None))
|
||||
|
||||
act_angles = act_angles / np.pi
|
||||
act_angles = pad_to_size(act_angles, (max_size, 3))
|
||||
|
||||
|
||||
|
||||
# onn_angles = onn_angles - np.min(onn_angles)
|
||||
# onn_angles = onn_angles / np.max(onn_angles)
|
||||
|
||||
# act_angles = act_angles - np.min(act_angles)
|
||||
# act_angles = act_angles / np.max(act_angles)
|
||||
layer_name, (num, onn_values, onn_angles) = plot.popitem()
|
||||
|
||||
if num == 0:
|
||||
value_img = np.concatenate((onn_values, act_values), axis=1)
|
||||
angle_img = np.concatenate((onn_angles, act_angles), axis=1)
|
||||
value_img = onn_values
|
||||
angle_img = onn_angles
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, None))
|
||||
onn_values = pad_to_size(onn_values, (max_size, None))
|
||||
else:
|
||||
value_img = np.concatenate((value_img, onn_values, act_values), axis=1)
|
||||
angle_img = np.concatenate((angle_img, onn_angles, act_angles), axis=1)
|
||||
|
||||
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
|
||||
value_img = np.concatenate((value_img, onn_values), axis=1)
|
||||
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
|
||||
|
||||
|
||||
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
|
||||
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
|
||||
|
||||
# -np.inf to np.nan
|
||||
# value_img[value_img == -np.inf] = np.nan
|
||||
|
||||
# angle_img += move_to_location_in_size(onn_angles, ((max_size+3)*num, 0), img_overall_size)
|
||||
# angle_img += move_to_location_in_size(act_angles, ((max_size+3)*(num+1) + 2, 0), img_overall_size)
|
||||
|
||||
|
||||
# from cmcrameri import cm
|
||||
from cmap import Colormap as cm
|
||||
import scicomap as sc
|
||||
# from matplotlib import colors as mcolors
|
||||
# alpha_map = mcolors.LinearSegmentedColormap(
|
||||
# 'alphamap',
|
||||
# {
|
||||
# 'red': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'green': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'blue': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'alpha': [
|
||||
# (0, 1, 1),
|
||||
# # (0.2, 0.2, 0.1),
|
||||
# (1, 0, 0)
|
||||
# ]
|
||||
# }
|
||||
# )
|
||||
# alpha_map.set_bad(color="#AAAAAA")
|
||||
|
||||
|
||||
from cmcrameri import cm
|
||||
from matplotlib import colors as mcolors
|
||||
alpha_map = mcolors.LinearSegmentedColormap(
|
||||
'alphamap',
|
||||
{
|
||||
'red': [(0, 0, 0), (1, 0, 0)],
|
||||
'green': [(0, 0, 0), (1, 0, 0)],
|
||||
'blue': [(0, 0, 0), (1, 0, 0)],
|
||||
'alpha': [
|
||||
(0, 1, 1),
|
||||
# (0.2, 0.2, 0.1),
|
||||
(1, 0, 0)
|
||||
]
|
||||
}
|
||||
)
|
||||
alpha_map.set_bad(color="#AAAAAA")
|
||||
|
||||
|
||||
fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(7, 8.5))
|
||||
fig.tight_layout()
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
|
||||
# fig.tight_layout()
|
||||
dividers = map(make_axes_locatable, axs)
|
||||
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
|
||||
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
|
||||
masked_value_img = value_img
|
||||
cmap = cm.batlowW
|
||||
cmap = cm('google:turbo').to_matplotlib()
|
||||
# cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_val = axs[0].imshow(masked_value_img, cmap=cmap)
|
||||
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
|
||||
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
|
||||
|
||||
|
||||
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
|
||||
cmap = cm.romaO
|
||||
# cmap = cm('crameri:romao').to_matplotlib()
|
||||
# cmap = plt.get_cmap('puccs')
|
||||
# cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
|
||||
cmap = cm('colorcet:CET_C8').to_matplotlib()
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap)
|
||||
im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
|
||||
im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
|
||||
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
|
||||
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
|
||||
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
|
||||
# im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
|
||||
# im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
|
||||
|
||||
axs[0].axis("off")
|
||||
axs[1].axis("off")
|
||||
axs[2].axis("off")
|
||||
# axs[2].axis("off")
|
||||
|
||||
axs[0].set_title("Values")
|
||||
axs[1].set_title("Angles")
|
||||
axs[2].set_title("Values and Angles")
|
||||
# axs[2].set_title("Values and Angles")
|
||||
|
||||
|
||||
...
|
||||
plt.show()
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
# model = models.regenerator(*dims, **model_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_plot(".models/best_20250105_145719.tar")
|
||||
register_puccs_cmap()
|
||||
if len(sys.argv) > 1:
|
||||
model_plot(sys.argv[1])
|
||||
else:
|
||||
print("Please provide a model path as an argument")
|
||||
# model_plot(".models/best_20250114_224234.tar")
|
||||
|
||||
102
src/single-core-regen/puccs.csv
Normal file
102
src/single-core-regen/puccs.csv
Normal file
@@ -0,0 +1,102 @@
|
||||
"x","L","a","b","R","G","B"
|
||||
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
|
||||
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
|
||||
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
|
||||
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
|
||||
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
|
||||
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
|
||||
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
|
||||
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
|
||||
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
|
||||
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
|
||||
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
|
||||
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
|
||||
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
|
||||
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
|
||||
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
|
||||
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
|
||||
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
|
||||
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
|
||||
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
|
||||
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
|
||||
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
|
||||
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
|
||||
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
|
||||
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
|
||||
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
|
||||
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
|
||||
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
|
||||
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
|
||||
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
|
||||
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
|
||||
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
|
||||
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
|
||||
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
|
||||
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
|
||||
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
|
||||
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
|
||||
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
|
||||
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
|
||||
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
|
||||
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
|
||||
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
|
||||
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
|
||||
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
|
||||
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
|
||||
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
|
||||
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
|
||||
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
|
||||
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
|
||||
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
|
||||
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
|
||||
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
|
||||
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
|
||||
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
|
||||
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
|
||||
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
|
||||
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
|
||||
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
|
||||
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
|
||||
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
|
||||
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
|
||||
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
|
||||
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
|
||||
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
|
||||
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
|
||||
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
|
||||
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
|
||||
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
|
||||
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
|
||||
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
|
||||
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
|
||||
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
|
||||
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
|
||||
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
|
||||
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
|
||||
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
|
||||
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
|
||||
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
|
||||
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
|
||||
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
|
||||
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
|
||||
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
|
||||
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
|
||||
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
|
||||
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
|
||||
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
|
||||
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
|
||||
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
|
||||
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
|
||||
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
|
||||
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
|
||||
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
|
||||
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
|
||||
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
|
||||
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
|
||||
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
|
||||
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
|
||||
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
|
||||
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
|
||||
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
|
||||
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
|
||||
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
|
||||
|
@@ -26,28 +26,39 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
|
||||
config_path="data/20250110-190528-128-16384-100000-0-0.2-17.0-0.058-PAM4-0-0.14-10.ini",
|
||||
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||
# config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
|
||||
# config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
|
||||
# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
|
||||
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
|
||||
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
|
||||
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
|
||||
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
|
||||
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
|
||||
|
||||
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
|
||||
|
||||
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=4, # study: single_core_regen_20241123_011232
|
||||
symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
|
||||
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||
output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
shuffle=True,
|
||||
drop_first=64,
|
||||
output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
|
||||
shuffle=False,
|
||||
drop_first=256,
|
||||
drop_last=256,
|
||||
train_split=0.8,
|
||||
randomise_polarisations=False,
|
||||
polarisations=True,
|
||||
polarisations=False,
|
||||
# cross_pol_interference=0.01,
|
||||
osnr=16, #16dB due to amplification with NF 5
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=1000,
|
||||
batchsize=2**14,
|
||||
batchsize=2**13,
|
||||
device="cuda",
|
||||
dataloader_workers=24,
|
||||
dataloader_prefetch=8,
|
||||
dataloader_workers=32,
|
||||
dataloader_prefetch=4,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
save_models=True,
|
||||
@@ -65,16 +76,13 @@ model_settings = ModelSettings(
|
||||
# "n_hidden_nodes_3": 4,
|
||||
# "n_hidden_nodes_4": 2,
|
||||
},
|
||||
model_activation_func="phase_shift",
|
||||
model_activation_func="EOActivation",
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_layer_kwargs={"square": True},
|
||||
scale=2.0,
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": util.complexNN.energy_conserving,
|
||||
},
|
||||
# EOactivation
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
@@ -83,54 +91,20 @@ model_settings = ModelSettings(
|
||||
"max": 1,
|
||||
},
|
||||
},
|
||||
# ONNRect
|
||||
{
|
||||
"tensor_name": "gain",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "phase_bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2 * torch.pi,
|
||||
},
|
||||
"tensor_name": "weight",
|
||||
"parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
},
|
||||
# Scale
|
||||
{
|
||||
"tensor_name": "scale",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2,
|
||||
"max": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "angle",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": -torch.pi,
|
||||
"max": torch.pi,
|
||||
},
|
||||
},
|
||||
# {
|
||||
# "tensor_name": "scale",
|
||||
# "parametrization": util.complexNN.clamp,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "bias",
|
||||
# "parametrization": util.complexNN.clamp,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "V",
|
||||
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
# },
|
||||
{
|
||||
"tensor_name": "loss",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@@ -145,191 +119,35 @@ optimizer_settings = OptimizerSettings(
|
||||
scheduler="ReduceLROnPlateau",
|
||||
scheduler_kwargs={
|
||||
"patience": 2**6,
|
||||
"factor": 0.75,
|
||||
"factor": 0.5,
|
||||
# "threshold": 1e-3,
|
||||
"min_lr": 1e-6,
|
||||
"cooldown": 10,
|
||||
},
|
||||
early_stopping=True,
|
||||
early_stop_kwargs={
|
||||
"threshold": 1e-06,
|
||||
"plateau": 2**7,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def save_dict_to_file(dictionary, filename):
|
||||
"""
|
||||
Save the best dictionary to a JSON file.
|
||||
|
||||
:param best: Dictionary containing the best training results.
|
||||
:type best: dict
|
||||
:param filename: Path to the JSON file where the dictionary will be saved.
|
||||
:type filename: str
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
json.dump(dictionary, f, indent=4)
|
||||
|
||||
|
||||
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
|
||||
assert model is not None, "Model must be provided."
|
||||
assert data_glob is not None, "Data glob must be provided."
|
||||
model = model
|
||||
|
||||
fiber_ins = {}
|
||||
fiber_outs = {}
|
||||
regens = {}
|
||||
timestampss = {}
|
||||
|
||||
trainer = RegenerationTrainer(
|
||||
checkpoint_path=model,
|
||||
)
|
||||
trainer.define_model()
|
||||
|
||||
for length in lengths:
|
||||
data_glob_length = data_glob.replace("{length}", str(length))
|
||||
files = list(Path.cwd().glob(data_glob_length))
|
||||
if len(files) == 0:
|
||||
continue
|
||||
if strategy == "newest":
|
||||
sorted_kwargs = {
|
||||
"key": lambda x: x.stat().st_mtime,
|
||||
"reverse": True,
|
||||
}
|
||||
elif strategy == "oldest":
|
||||
sorted_kwargs = {
|
||||
"key": lambda x: x.stat().st_mtime,
|
||||
"reverse": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy {strategy}.")
|
||||
file = sorted(files, **sorted_kwargs)[0]
|
||||
|
||||
loader, _ = trainer.get_sliced_data(override={"config_path": file})
|
||||
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
||||
|
||||
fiber_ins[length] = fiber_in
|
||||
fiber_outs[length] = fiber_out
|
||||
regens[length] = regen
|
||||
timestampss[length] = timestamps
|
||||
|
||||
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
|
||||
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
|
||||
|
||||
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
|
||||
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
|
||||
|
||||
channel_names[1] = "fiber in x"
|
||||
|
||||
for li, length in enumerate(timestampss.keys()):
|
||||
data[2 + 2 * li, 0, :] = timestampss[length] / 128
|
||||
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
||||
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
|
||||
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
|
||||
|
||||
channel_names[2 + 2 * li + 1] = f"regen x {length}"
|
||||
channel_names[2 + 2 * li] = f"fiber out x {length}"
|
||||
|
||||
# get current backend
|
||||
backend = matplotlib.get_backend()
|
||||
|
||||
matplotlib.use("TkCairo")
|
||||
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
||||
|
||||
print_attrs = ("channel_name", "success", "min_area")
|
||||
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
|
||||
for result in eye.eye_stats:
|
||||
print_dict = {attr: result[attr] for attr in print_attrs}
|
||||
rprint(print_dict)
|
||||
rprint()
|
||||
|
||||
eye.plot(all_stats=False)
|
||||
matplotlib.use(backend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# lengths = range(90000, 100000+10000, 10000)
|
||||
# lengths = [100000]
|
||||
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
|
||||
|
||||
|
||||
trainer = RegenerationTrainer(
|
||||
global_settings=global_settings,
|
||||
data_settings=data_settings,
|
||||
pytorch_settings=pytorch_settings,
|
||||
model_settings=model_settings,
|
||||
optimizer_settings=optimizer_settings,
|
||||
# checkpoint_path=".models/best_20250104_191428.tar",
|
||||
reset_epoch=True,
|
||||
# settings_override={
|
||||
# "data_settings": {
|
||||
# "config_path": "data/20241229-163*-128-16384-100000-*.ini",
|
||||
# "polarisations": True,
|
||||
# },
|
||||
# "model_settings": {
|
||||
# "scale": 2.0,
|
||||
# }
|
||||
# }
|
||||
checkpoint_path=".models/best_20250117_144001.tar",
|
||||
new_model=True,
|
||||
settings_override={
|
||||
"data_settings": data_settings.__dict__,
|
||||
# "optimizer_settings": {
|
||||
# "optimizer_kwargs": {
|
||||
# "lr": 0.01,
|
||||
# },
|
||||
# "early_stop_kwargs":{
|
||||
# "plateau": 2**8,
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# 20241202_143149
|
||||
}
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# from hypertraining.lighning_models import regenerator, regeneratorData
|
||||
# import lightning as L
|
||||
|
||||
# model = regenerator(
|
||||
# 2 * data_settings.output_size,
|
||||
# *model_settings.overrides["hidden_layer_dims"],
|
||||
# model_settings.output_dim,
|
||||
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
|
||||
# layer_func_kwargs=model_settings.model_layer_kwargs,
|
||||
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
|
||||
# act_func_kwargs=None,
|
||||
# parametrizations=model_settings.model_layer_parametrizations,
|
||||
# dtype=getattr(torch, data_settings.dtype),
|
||||
# dropout_prob=model_settings.dropout_prob,
|
||||
# scale_layers=model_settings.scale,
|
||||
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
|
||||
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
|
||||
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
|
||||
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
|
||||
# )
|
||||
|
||||
# dm = regeneratorData(
|
||||
# config_globs=data_settings.config_path,
|
||||
# output_symbols=data_settings.symbols,
|
||||
# output_dim=data_settings.output_size,
|
||||
# dtype=getattr(torch, data_settings.dtype),
|
||||
# drop_first=data_settings.drop_first,
|
||||
# shuffle=data_settings.shuffle,
|
||||
# train_split=data_settings.train_split,
|
||||
# batch_size=pytorch_settings.batchsize,
|
||||
# loader_settings={
|
||||
# "num_workers": pytorch_settings.dataloader_workers,
|
||||
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
|
||||
# "pin_memory": True,
|
||||
# "drop_last": True,
|
||||
# },
|
||||
# seed=global_settings.seed,
|
||||
# )
|
||||
|
||||
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
||||
|
||||
# # from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
|
||||
|
||||
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
|
||||
|
||||
# trainer = L.Trainer(
|
||||
# fast_dev_run=False,
|
||||
# # max_epochs=pytorch_settings.epochs,
|
||||
# max_epochs=2,
|
||||
# enable_checkpointing=True,
|
||||
# default_root_dir=f".models/{subdir}/",
|
||||
# logger=logger,
|
||||
# )
|
||||
|
||||
# trainer.fit(model, dm)
|
||||
|
||||
@@ -12,6 +12,7 @@ Full license text in LICENSE file
|
||||
"""
|
||||
|
||||
import configparser
|
||||
# import copy
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
@@ -40,7 +41,7 @@ alpha = 0.2
|
||||
D = 17
|
||||
S = 0.058
|
||||
bireflength = 10
|
||||
max_delta_beta = 0.14
|
||||
pmd_q = 0.2
|
||||
; birefseed = 0xC0FFEE
|
||||
|
||||
[signal]
|
||||
@@ -195,10 +196,14 @@ class pam_generator:
|
||||
|
||||
|
||||
def initialize_fiber_and_data(config):
|
||||
f0 = config["glova"].get("f0", None)
|
||||
if f0 is None:
|
||||
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
|
||||
config["glova"]["f0"] = f0
|
||||
py_glova = pypho.setup(
|
||||
nos=config["glova"]["nos"],
|
||||
sps=config["glova"]["sps"],
|
||||
f0=config["glova"]["f0"],
|
||||
f0=f0,
|
||||
symbolrate=config["glova"]["symbolrate"],
|
||||
wisdom_dir=config["glova"]["wisdom_dir"],
|
||||
flags=config["glova"]["flags"],
|
||||
@@ -216,7 +221,9 @@ def initialize_fiber_and_data(config):
|
||||
symbolsrc = pypho.symbols(
|
||||
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
|
||||
)
|
||||
laser = pypho.lasmod(py_glova, power=config["signal"]["laser_power"], Df=0, theta=np.pi / 4)
|
||||
laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
|
||||
# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
|
||||
|
||||
modulator = pam_generator(
|
||||
py_glova,
|
||||
mod_depth=config["signal"]["mod_depth"],
|
||||
@@ -232,7 +239,12 @@ def initialize_fiber_and_data(config):
|
||||
symbols_y[:3] = 0
|
||||
# symbols_x += 1
|
||||
|
||||
cw = laser()
|
||||
|
||||
cw = laserx()
|
||||
# cwy = lasery()
|
||||
# cw[0]['E'][0] = cw[0]['E'][0]
|
||||
# cw[0]['E'][1] = cwy[0]['E'][0]
|
||||
|
||||
|
||||
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
|
||||
|
||||
@@ -251,13 +263,41 @@ def initialize_fiber_and_data(config):
|
||||
|
||||
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
|
||||
|
||||
nf = py_edfa.NF
|
||||
source_signal = py_edfa(E=source_signal, NF=0)
|
||||
py_edfa.NF = nf
|
||||
## side channels
|
||||
# df = 100
|
||||
# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
|
||||
|
||||
|
||||
# symbols_x_side = symbolsrc(pattern="random")
|
||||
# symbols_y_side = symbolsrc(pattern="random")
|
||||
# symbols_x_side[:3] = 0
|
||||
# symbols_y_side[:3] = 0
|
||||
|
||||
# cw_left = laser(Df=-df)
|
||||
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
|
||||
|
||||
# cw_right = laser(Df=df)
|
||||
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
|
||||
|
||||
E_in_pure = source_signal[0]["E"]
|
||||
|
||||
nf = py_edfa.NF
|
||||
pmean = py_edfa.Pmean
|
||||
|
||||
# ideal amplification to launch power into fiber
|
||||
source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
|
||||
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
|
||||
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
|
||||
|
||||
c_data.E_in = source_signal[0]["E"]
|
||||
noise = source_signal[0]["noise"]
|
||||
|
||||
py_edfa.NF = nf
|
||||
py_edfa.Pmean = pmean
|
||||
|
||||
py_fiber = pypho.fiber(
|
||||
glova=py_glova,
|
||||
l=config["fiber"]["length"],
|
||||
@@ -265,20 +305,29 @@ def initialize_fiber_and_data(config):
|
||||
gamma=config["fiber"]["gamma"],
|
||||
D=config["fiber"]["d"],
|
||||
S=config["fiber"]["s"],
|
||||
phi_max=0.02,
|
||||
)
|
||||
if config["fiber"].get("birefsteps", 0) > 0:
|
||||
|
||||
config["fiber"]["birefsteps"] = config["fiber"].get(
|
||||
"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
|
||||
)
|
||||
if config["fiber"]["birefsteps"] > 0:
|
||||
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
|
||||
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
|
||||
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
|
||||
py_fiber.l,
|
||||
py_fiber.l / config["fiber"]["birefsteps"],
|
||||
# maxDeltaD=config["fiber"]["d"]/5,
|
||||
maxDeltaBeta=config["fiber"].get("max_delta_beta", 0),
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["bireflength"],
|
||||
maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
|
||||
seed=seed,
|
||||
)
|
||||
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200)
|
||||
elif (dgd := config['fiber'].get('dgd', 0)) > 0:
|
||||
py_fiber.birefarray = [
|
||||
pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
|
||||
]
|
||||
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
|
||||
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
|
||||
|
||||
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y)
|
||||
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
|
||||
|
||||
|
||||
def save_data(data, config, **metadata):
|
||||
@@ -316,8 +365,11 @@ def save_data(data, config, **metadata):
|
||||
f"D = {config['fiber']['d']}",
|
||||
f"S = {config['fiber']['s']}",
|
||||
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
|
||||
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
|
||||
f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
|
||||
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
|
||||
f"dgd = {config['fiber'].get('dgd', 0)}",
|
||||
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
|
||||
f"pol_error = {config['fiber'].get('pol_error', 0)}",
|
||||
"",
|
||||
"[signal]",
|
||||
f"seed = {hex(seed)}" if seed else "; seed = not set",
|
||||
@@ -346,24 +398,12 @@ def save_data(data, config, **metadata):
|
||||
save_file = f"{config_hash}.h5"
|
||||
config_content += f'"{str(save_file)}"\n'
|
||||
|
||||
filename_components = (
|
||||
timestamp.strftime("%Y%m%d-%H%M%S"),
|
||||
config["glova"]["sps"],
|
||||
config["glova"]["nos"],
|
||||
config["signal"]["osnr"],
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["gamma"],
|
||||
config["fiber"]["alpha"],
|
||||
config["fiber"]["d"],
|
||||
config["fiber"]["s"],
|
||||
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
|
||||
config["fiber"].get("birefsteps", 0),
|
||||
config["fiber"].get("max_delta_beta", 0),
|
||||
int(config["glova"]["symbolrate"] / 1e9),
|
||||
)
|
||||
config_filename:Path = create_config_filename(config, data_dir, timestamp)
|
||||
while config_filename.exists():
|
||||
time.sleep(1)
|
||||
config_filename = create_config_filename(config, data_dir=data_dir)
|
||||
|
||||
|
||||
lookup_file = "-".join(map(str, filename_components)) + ".ini"
|
||||
config_filename = data_dir / lookup_file
|
||||
with open(config_filename, "w") as f:
|
||||
f.write(config_content)
|
||||
|
||||
@@ -376,11 +416,31 @@ def save_data(data, config, **metadata):
|
||||
outfile.attrs[key] = value
|
||||
# np.save(save_dir / save_file, save_data)
|
||||
|
||||
print("Saved config to", config_filename)
|
||||
print("Saved data to", save_dir / save_file)
|
||||
# print("Saved config to", config_filename)
|
||||
# print("Saved data to", save_dir / save_file)
|
||||
|
||||
return config_filename
|
||||
|
||||
def create_config_filename(config, data_dir:Path, timestamp=None):
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
filename_components = (
|
||||
timestamp.strftime("%Y%m%d-%H%M%S"),
|
||||
config["glova"]["sps"],
|
||||
config["glova"]["nos"],
|
||||
config["signal"]["osnr"],
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["gamma"],
|
||||
config["fiber"]["alpha"],
|
||||
config["fiber"]["d"],
|
||||
config["fiber"]["s"],
|
||||
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
|
||||
config["fiber"].get("birefsteps", 0),
|
||||
config["fiber"].get("pmd_q", 0),
|
||||
int(config["glova"]["symbolrate"] / 1e9),
|
||||
)
|
||||
lookup_file = "-".join(map(str, filename_components)) + ".ini"
|
||||
return data_dir / lookup_file
|
||||
|
||||
def length_loop(config, lengths, save=True):
|
||||
lengths = sorted(lengths)
|
||||
@@ -388,7 +448,7 @@ def length_loop(config, lengths, save=True):
|
||||
print(f"\nGenerating data for fiber length {length}m")
|
||||
config["fiber"]["length"] = length
|
||||
|
||||
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
|
||||
cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config)
|
||||
|
||||
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
|
||||
cfiber()
|
||||
@@ -416,51 +476,49 @@ def single_run_with_plot(config, save=True):
|
||||
in_out_eyes(cfiber, cdata, show_pols=False)
|
||||
return config_filename
|
||||
|
||||
def single_run(config, save=True):
|
||||
cfiber, cdata, noise, edfa, symbols = initialize_fiber_and_data(config)
|
||||
|
||||
# mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
|
||||
# print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
|
||||
|
||||
# estimate osnr
|
||||
# noise_power = np.mean(noise)
|
||||
# osnr_lin = mean_power_in / noise_power - 1
|
||||
# osnr = 10 * np.log10(osnr_lin)
|
||||
# print(f"Estimated OSNR: {osnr:.3f} dB")
|
||||
def single_run(config, save=True, silent=True):
|
||||
cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
|
||||
|
||||
# transmit
|
||||
cfiber()
|
||||
|
||||
# mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
# print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
|
||||
|
||||
# noise = noise * np.exp(-cfiber.params.l * cfiber.params.alpha)
|
||||
|
||||
# estimate osnr
|
||||
# noise_power = np.mean(noise)
|
||||
# osnr_lin = mean_power_out / noise_power - 1
|
||||
# osnr = 10 * np.log10(osnr_lin)
|
||||
# print(f"Estimated OSNR: {osnr:.3f} dB")
|
||||
|
||||
# amplify
|
||||
E_tmp = [{"E": cdata.E_out, "noise": noise}]
|
||||
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
|
||||
|
||||
# rotate
|
||||
# ortho error
|
||||
ortho_error = config["fiber"].get("ortho_error", 0)
|
||||
|
||||
E_tmp[0]["E"] = np.stack((
|
||||
E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
|
||||
E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
|
||||
), axis=0)
|
||||
|
||||
|
||||
pol_error = config['fiber'].get('pol_error', 0)
|
||||
|
||||
E_tmp[0]["E"] = np.stack((
|
||||
E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
|
||||
E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
|
||||
), axis=0)
|
||||
|
||||
|
||||
|
||||
|
||||
# output
|
||||
cdata.E_out = E_tmp[0]["E"]
|
||||
# noise = E_tmp[0]["noise"]
|
||||
|
||||
# mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
|
||||
# print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
|
||||
|
||||
# estimate osnr
|
||||
# noise_power = np.mean(noise)
|
||||
# osnr_lin = mean_power_amp / noise_power - 1
|
||||
# osnr = 10 * np.log10(osnr_lin)
|
||||
# print(f"Estimated OSNR: {osnr:.3f} dB")
|
||||
|
||||
config_filename = None
|
||||
symbols = np.array(symbols)
|
||||
if save:
|
||||
config_filename = save_data(cdata, config, **{"symbols": symbols})
|
||||
return cfiber,cdata,config_filename
|
||||
if not silent:
|
||||
print(f"Saved config to {config_filename}")
|
||||
return cfiber, cdata, config_filename
|
||||
|
||||
|
||||
def in_out_eyes(cfiber, cdata, show_pols=False):
|
||||
|
||||
138
src/single-core-regen/testing/prob_dens.ipynb
Normal file
138
src/single-core-regen/testing/prob_dens.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -441,8 +441,7 @@ class input_rotator(nn.Module):
|
||||
# return out
|
||||
|
||||
|
||||
#### as defined by zhang et al
|
||||
|
||||
#### as defined by zhang et alas
|
||||
|
||||
class DropoutComplex(nn.Module):
|
||||
def __init__(self, p=0.5):
|
||||
@@ -464,7 +463,7 @@ class Scale(nn.Module):
|
||||
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.scale
|
||||
return x * torch.sqrt(self.scale)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Scale({self.size})"
|
||||
@@ -546,35 +545,31 @@ class EOActivation(nn.Module):
|
||||
raise ValueError("Size must be specified")
|
||||
self.size = size
|
||||
self.alpha = nn.Parameter(torch.rand(size))
|
||||
self.V_bias = nn.Parameter(torch.rand(size))
|
||||
self.gain = nn.Parameter(torch.rand(size))
|
||||
# if bias:
|
||||
# self.phase_bias = nn.Parameter(torch.zeros(size))
|
||||
# else:
|
||||
# self.register_buffer("phase_bias", torch.zeros(size))
|
||||
# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
|
||||
self.register_buffer("responsivity", torch.ones(size)*0.9)
|
||||
self.register_buffer("V_pi", torch.ones(size)*3)
|
||||
self.V_bias = nn.Parameter(torch.rand(size))
|
||||
# self.register_buffer("gain", torch.ones(size))
|
||||
# self.register_buffer("responsivity", torch.ones(size))
|
||||
# self.register_buffer("V_pi", torch.ones(size))
|
||||
|
||||
self.reset_weights()
|
||||
|
||||
def reset_weights(self):
|
||||
if "alpha" in self._parameters:
|
||||
self.alpha.data = torch.rand(self.size)
|
||||
if "V_pi" in self._parameters:
|
||||
self.V_pi.data = torch.rand(self.size)*3
|
||||
# if "V_pi" in self._parameters:
|
||||
# self.V_pi.data = torch.rand(self.size)*3
|
||||
if "V_bias" in self._parameters:
|
||||
self.V_bias.data = torch.randn(self.size)
|
||||
if "gain" in self._parameters:
|
||||
self.gain.data = torch.rand(self.size)
|
||||
if "responsivity" in self._parameters:
|
||||
self.responsivity.data = torch.ones(self.size)*0.9
|
||||
# if "responsivity" in self._parameters:
|
||||
# self.responsivity.data = torch.ones(self.size)*0.9
|
||||
# if "bias" in self._parameters:
|
||||
# self.phase_bias.data = torch.zeros(self.size)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
|
||||
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
|
||||
phi_b = torch.pi * self.V_bias# / (self.V_pi)
|
||||
g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
|
||||
intermediate = g_phi * x.abs().square() + phi_b
|
||||
return (
|
||||
1j
|
||||
|
||||
105
src/single-core-regen/util/core.py
Normal file
105
src/single-core-regen/util/core.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
|
||||
# This software is licensed according to the "BSD 2-clause" license.
|
||||
|
||||
import hashlib
|
||||
import h5py
|
||||
import numpy as _np
|
||||
from scipy.interpolate import interp1d as _interp1d
|
||||
from scipy.ndimage import gaussian_filter as _gaussian_filter
|
||||
from ._brescount import bres_curve_count as _bres_curve_count
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
__all__ = ['grid_count']
|
||||
|
||||
|
||||
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
`y` is the 1-d array of signal samples.
|
||||
|
||||
`window_size` is the number of samples to show horizontally in the
|
||||
eye diagram. Typically this is twice the number of samples in a
|
||||
"symbol" (i.e. in a data bit).
|
||||
|
||||
`offset` is the number of initial samples to skip before computing
|
||||
the eye diagram. This allows the overall phase of the diagram to
|
||||
be adjusted.
|
||||
|
||||
`size` must be a tuple of two integers. It sets the size of the
|
||||
array of counts, (height, width). The default is (800, 640).
|
||||
|
||||
`fuzz`: If True, the values in `y` are reinterpolated with a
|
||||
random "fuzz factor" before plotting in the eye diagram. This
|
||||
reduces an aliasing-like effect that arises with the use of
|
||||
Bresenham's algorithm.
|
||||
|
||||
`bounds` must be a tuple of two floating point values, (ymin, ymax).
|
||||
These set the y range of the returned array. If not given, the
|
||||
bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
|
||||
`y.max() - y.min()`.
|
||||
|
||||
Return Value
|
||||
------------
|
||||
Returns a numpy array of integers.
|
||||
|
||||
"""
|
||||
# hash input params
|
||||
param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
|
||||
param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
|
||||
cache_dir = Path.home()/".eyediagram"/".cache"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
if (cache_dir/param_hash).is_file():
|
||||
try:
|
||||
with h5py.File(cache_dir/param_hash, "r") as infile:
|
||||
counts = infile["counts"][:]
|
||||
if counts.len() != 0:
|
||||
return counts
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if size is None:
|
||||
size = (800, 640)
|
||||
height, width = size
|
||||
dt = width / window_size
|
||||
counts = _np.zeros((width, height), dtype=_np.int32)
|
||||
|
||||
if bounds is None:
|
||||
ymin = y.min()
|
||||
ymax = y.max()
|
||||
yamp = ymax - ymin
|
||||
ymin = ymin - 0.05*yamp
|
||||
ymax = ymax + 0.05*yamp
|
||||
ymax = _np.ceil(ymax*10)/10
|
||||
ymin = _np.floor(ymin*10)/10
|
||||
else:
|
||||
ymin, ymax = bounds
|
||||
|
||||
start = offset
|
||||
while start + window_size < len(y):
|
||||
end = start + window_size
|
||||
yy = y[start:end+1]
|
||||
k = _np.arange(len(yy))
|
||||
xx = dt*k
|
||||
if fuzz:
|
||||
f = _interp1d(xx, yy, kind='cubic')
|
||||
jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
|
||||
xx[1:-1] += jiggle
|
||||
yd = f(xx)
|
||||
else:
|
||||
yd = yy
|
||||
iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
|
||||
_bres_curve_count(xx.astype(_np.int32), iyd, counts)
|
||||
|
||||
start = end
|
||||
|
||||
if blur != 0:
|
||||
counts = _gaussian_filter(counts, sigma=blur)
|
||||
|
||||
with h5py.File(cache_dir/param_hash, "w") as outfile:
|
||||
outfile.create_dataset("data", data=counts)
|
||||
|
||||
return counts
|
||||
@@ -25,13 +25,14 @@ import multiprocessing as mp
|
||||
# def __len__(self):
|
||||
# return len(self.indices)
|
||||
|
||||
|
||||
def load_from_file(datapath):
|
||||
if str(datapath).endswith('.h5'):
|
||||
if str(datapath).endswith(".h5"):
|
||||
symbols = None
|
||||
with h5py.File(datapath, "r") as infile:
|
||||
data = infile["data"][:]
|
||||
try:
|
||||
symbols = infile["symbols"][:]
|
||||
symbols = np.swapaxes(infile["symbols"][:], 0, 1)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@@ -40,7 +41,7 @@ def load_from_file(datapath):
|
||||
return data, symbols
|
||||
|
||||
|
||||
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
|
||||
def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
|
||||
filepath = Path(config_path)
|
||||
filepath = filepath.parent.glob(filepath.name)
|
||||
config = configparser.ConfigParser()
|
||||
@@ -55,15 +56,23 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=1, d
|
||||
|
||||
if symbols is None:
|
||||
symbols = int(config["glova"]["nos"]) - skipfirst
|
||||
|
||||
|
||||
data, orig_symbols = load_from_file(datapath)
|
||||
|
||||
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
|
||||
orig_symbols = orig_symbols[skipfirst:symbols+skipfirst]
|
||||
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
|
||||
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
|
||||
orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
|
||||
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
|
||||
|
||||
data *= np.sqrt(normalize)
|
||||
|
||||
launch_power = float(config["signal"]["laser_power"])
|
||||
output_power = float(config["signal"]["edfa_power"])
|
||||
|
||||
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
|
||||
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
|
||||
|
||||
data[:, 0:2] *= np.sqrt(target_normalization)
|
||||
|
||||
# if normalize:
|
||||
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||
# a, b, c, d = data.T
|
||||
@@ -132,13 +141,15 @@ class FiberRegenerationDataset(Dataset):
|
||||
target_delay: float | int = 0,
|
||||
xy_delay: float | int = 0,
|
||||
drop_first: float | int = 0,
|
||||
drop_last=0,
|
||||
dtype: torch.dtype = None,
|
||||
real: bool = False,
|
||||
device=None,
|
||||
# osnr: float|None = None,
|
||||
polarisations = None,
|
||||
polarisations=None,
|
||||
randomise_polarisations: bool = False,
|
||||
repeat_randoms: int = 1,
|
||||
# cross_pol_interference: float = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -172,6 +183,7 @@ class FiberRegenerationDataset(Dataset):
|
||||
assert drop_first >= 0, "drop_first must be non-negative"
|
||||
|
||||
self.randomise_polarisations = randomise_polarisations
|
||||
# self.cross_pol_interference = cross_pol_interference
|
||||
|
||||
data_raw = None
|
||||
self.config = None
|
||||
@@ -181,6 +193,7 @@ class FiberRegenerationDataset(Dataset):
|
||||
data, config, orig_syms = load_data(
|
||||
file_path,
|
||||
skipfirst=drop_first,
|
||||
skiplast=drop_last,
|
||||
symbols=kwargs.get("num_symbols", None),
|
||||
real=real,
|
||||
normalize=1000,
|
||||
@@ -192,7 +205,7 @@ class FiberRegenerationDataset(Dataset):
|
||||
self.orig_symbols = orig_syms
|
||||
else:
|
||||
self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)
|
||||
|
||||
|
||||
if data_raw is None:
|
||||
data_raw = data
|
||||
else:
|
||||
@@ -300,20 +313,18 @@ class FiberRegenerationDataset(Dataset):
|
||||
# fiber_out: [E_out_x, E_out_y, timestamps]
|
||||
|
||||
# add noise related to amplification necessary due to splitting of the signal
|
||||
gain_lin = output_dim*2
|
||||
edfa_nf = float(self.config["signal"]["edfa_nf"])
|
||||
nf_lin = 10**(edfa_nf/10)
|
||||
f0 = float(self.config["glova"]["f0"])
|
||||
|
||||
noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
|
||||
|
||||
noise = torch.randn_like(fiber_out[:2, :])
|
||||
noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
|
||||
noise = noise * torch.sqrt(noise_add / noise_power)
|
||||
fiber_out[:2, :] += noise
|
||||
|
||||
# gain_lin = output_dim*2
|
||||
# gain_lin = 1
|
||||
# edfa_nf = float(self.config["signal"]["edfa_nf"])
|
||||
# nf_lin = 10**(edfa_nf/10)
|
||||
# f0 = float(self.config["glova"]["f0"])
|
||||
|
||||
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
|
||||
|
||||
# noise = torch.randn_like(fiber_out[:2, :])
|
||||
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
|
||||
# noise = noise * torch.sqrt(noise_add / noise_power)
|
||||
# fiber_out[:2, :] += noise
|
||||
|
||||
# if osnr is None:
|
||||
# noisy = fiber_out[:2, :]
|
||||
@@ -324,7 +335,6 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
|
||||
|
||||
|
||||
if repeat_randoms > 1:
|
||||
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
|
||||
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
|
||||
@@ -334,12 +344,13 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
if self.randomise_polarisations:
|
||||
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
|
||||
# start_angle = torch.rand(1) * 2 * torch.pi
|
||||
# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
|
||||
start_angle = torch.rand(1) * 2 * torch.pi
|
||||
angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
|
||||
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
|
||||
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
|
||||
else:
|
||||
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
|
||||
|
||||
|
||||
sin = torch.sin(angles)
|
||||
cos = torch.cos(angles)
|
||||
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
|
||||
@@ -353,16 +364,14 @@ class FiberRegenerationDataset(Dataset):
|
||||
# 1 E_in_y,
|
||||
# 2 timestamps
|
||||
|
||||
# fiber_out:
|
||||
# 0 E_out_x,
|
||||
# 1 E_out_y,
|
||||
# fiber_out:
|
||||
# 0 E_out_x,
|
||||
# 1 E_out_y,
|
||||
# 2 timestamps,
|
||||
# 3 E_out_x_rot,
|
||||
# 4 E_out_y_rot,
|
||||
# 3 E_out_x_rot,
|
||||
# 4 E_out_y_rot,
|
||||
# 5 angle
|
||||
|
||||
|
||||
|
||||
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||
# data layout
|
||||
# [ [E_in_x, E_in_y, timestamps],
|
||||
@@ -374,9 +383,12 @@ class FiberRegenerationDataset(Dataset):
|
||||
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.fiber_out = self.fiber_out.movedim(-2, 0)
|
||||
|
||||
# if self.randomise_polarisations:
|
||||
# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
|
||||
|
||||
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
# self.data = self.data.movedim(-2, 0)
|
||||
# self.angles = torch.zeros(self.data.shape[0])
|
||||
# self.angles = torch.zeros(self.data.shape[0])
|
||||
...
|
||||
# ...
|
||||
# -> [no_slices, 2, 3, samples_per_slice]
|
||||
@@ -390,14 +402,14 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return self.fiber_in.shape[0]
|
||||
|
||||
|
||||
def add_noise(self, data, osnr):
|
||||
osnr_lin = 10**(osnr/10)
|
||||
osnr_lin = 10 ** (osnr / 10)
|
||||
popt = torch.mean(data.abs().square().squeeze(), dim=-1)
|
||||
noise = torch.randn_like(data)
|
||||
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
|
||||
|
||||
mult = torch.sqrt(popt/(pn*osnr_lin))
|
||||
|
||||
mult = torch.sqrt(popt / (pn * osnr_lin))
|
||||
mult = mult * torch.eye(popt.shape[0], device=mult.device)
|
||||
mult = mult.to(dtype=noise.dtype)
|
||||
|
||||
@@ -406,7 +418,6 @@ class FiberRegenerationDataset(Dataset):
|
||||
noisy = data + noise
|
||||
return noisy
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, slice):
|
||||
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
|
||||
@@ -418,6 +429,10 @@ class FiberRegenerationDataset(Dataset):
|
||||
output_dim = self.output_dim // 2
|
||||
self.output_dim = output_dim * 2
|
||||
|
||||
if not self.polarisations:
|
||||
output_dim = 2 * output_dim
|
||||
|
||||
|
||||
fiber_in = self.fiber_in[idx].squeeze()
|
||||
fiber_out = self.fiber_out[idx].squeeze()
|
||||
|
||||
@@ -427,85 +442,35 @@ class FiberRegenerationDataset(Dataset):
|
||||
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
|
||||
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
|
||||
|
||||
|
||||
# data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
|
||||
|
||||
# data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
|
||||
|
||||
# angle = self.angles[idx]
|
||||
|
||||
# fiber_in:
|
||||
# 0 E_in_x,
|
||||
# 1 E_in_y,
|
||||
# 2 timestamps
|
||||
|
||||
# fiber_out:
|
||||
# 0 E_out_x,
|
||||
# 1 E_out_y,
|
||||
# 2 timestamps,
|
||||
# 3 E_out_x_rot,
|
||||
# 4 E_out_y_rot,
|
||||
# 5 angle
|
||||
|
||||
center_angle = fiber_out[0, output_dim // 2, 0]
|
||||
center_angle = fiber_out[5, output_dim // 2, 0]
|
||||
angles = fiber_out[5, :, 0]
|
||||
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
|
||||
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
|
||||
data = fiber_out[0:2, :, 0]
|
||||
# fiber_out_plot_clean = fiber_out[:2, output_dim // 2, 0].detach().clone()
|
||||
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
|
||||
|
||||
|
||||
# if self.polarisations:
|
||||
# rot = int(np.random.randint(2)*2-1)
|
||||
# pol_flipped_data[0:1, :] = rot*data[0, :]
|
||||
# pol_flipped_data[1, :] = rot*data[1, :]
|
||||
# plot_data_rot[0] = rot*plot_data_rot[0]
|
||||
# plot_data_rot[1] = rot*plot_data_rot[1]
|
||||
# center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
|
||||
# angles = angles + (rot - 1) * torch.pi/2
|
||||
|
||||
|
||||
# if self.randomise_polarisations:
|
||||
# data = data.mT
|
||||
# c = torch.cos(angle).unsqueeze(-1)
|
||||
# s = torch.sin(angle).unsqueeze(-1)
|
||||
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
|
||||
# data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
|
||||
...
|
||||
|
||||
# angle = torch.zeros_like(angle)
|
||||
|
||||
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
|
||||
# angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, output_dim)
|
||||
# angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, output_dim)
|
||||
# sop = self.polarimeter(plot_data)
|
||||
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
|
||||
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
|
||||
target = fiber_in[:2, output_dim // 2, 0]
|
||||
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
|
||||
target_timestamp = fiber_in[2, output_dim // 2, 0].real
|
||||
...
|
||||
|
||||
if self.polarisations:
|
||||
rot = int(np.random.randint(2)*2-1)
|
||||
data = rot*data
|
||||
target = rot*target
|
||||
plot_data_rot = rot*plot_data_rot
|
||||
center_angle = center_angle + (rot - 1) * torch.pi/2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
|
||||
angles = angles + (rot - 1) * torch.pi/2
|
||||
rot = int(np.random.randint(2) * 2 - 1)
|
||||
data = rot * data
|
||||
target = rot * target
|
||||
plot_data_rot = rot * plot_data_rot
|
||||
center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
|
||||
angles = angles + (rot - 1) * torch.pi / 2
|
||||
|
||||
pol_flipped_data = -data
|
||||
pol_flipped_target = -target
|
||||
|
||||
# data_timestamps = data[-1,:].real
|
||||
# data = data[:-1, :]
|
||||
# target_timestamp = target[-1].real
|
||||
# target = target[:-1]
|
||||
# plot_data = plot_data[:-1]
|
||||
|
||||
# transpose to interleave the x and y data in the output tensor
|
||||
data = data.transpose(0, 1).flatten().squeeze()
|
||||
data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
|
||||
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
|
||||
pol_flipped_data = pol_flipped_data / torch.sqrt(
|
||||
torch.ones(1) * len(pol_flipped_data)
|
||||
) # power loss due to splitting
|
||||
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
|
||||
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
|
||||
center_angle = center_angle.flatten().squeeze()
|
||||
@@ -526,10 +491,10 @@ class FiberRegenerationDataset(Dataset):
|
||||
"y": target,
|
||||
"y_flipped": pol_flipped_target,
|
||||
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
|
||||
# "center_angle": center_angle,
|
||||
# "angles": angles,
|
||||
"center_angle": center_angle,
|
||||
"angles": angles,
|
||||
"mean_angle": angles.mean(),
|
||||
# "sop": sop,
|
||||
# "sop": sop,
|
||||
# "angle_data": angle_data,
|
||||
# "angle_data2": angle_data2,
|
||||
"timestamp": target_timestamp,
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
import h5py
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
# from cmap import Colormap as cm
|
||||
import numpy as np
|
||||
from scipy.cluster.vq import kmeans2
|
||||
import warnings
|
||||
import multiprocessing
|
||||
|
||||
from rich.traceback import install
|
||||
from rich import pretty
|
||||
from rich import print
|
||||
|
||||
install()
|
||||
pretty.install()
|
||||
# from rich import pretty
|
||||
# from rich import print
|
||||
|
||||
# pretty.install()
|
||||
|
||||
|
||||
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
||||
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
||||
xaxis = np.arange(0, len(signal)) / sps
|
||||
return np.vstack([xaxis, signal])
|
||||
|
||||
|
||||
def create_symbol_sequence(n_symbols, skew=1):
|
||||
np.random.seed(42)
|
||||
data = np.random.randint(0, 4, n_symbols) / 4
|
||||
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
|
||||
signal = np.convolve(data_padded, wavelet)
|
||||
signal = np.cumsum(signal)
|
||||
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
mi, ma = np.min(signal), np.max(signal)
|
||||
|
||||
signal = (signal - mi) / (ma - mi)
|
||||
|
||||
mod = 0.8
|
||||
|
||||
signal *= mod
|
||||
signal += 1 - mod
|
||||
|
||||
return signal
|
||||
|
||||
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
|
||||
signal += awgn
|
||||
|
||||
# min-max normalization
|
||||
signal = signal - np.min(signal)
|
||||
signal = signal / np.max(signal)
|
||||
# signal = signal - np.min(signal)
|
||||
# signal = signal / np.max(signal)
|
||||
return signal
|
||||
|
||||
|
||||
@@ -68,84 +84,248 @@ def generate_wavelet(sps, oversample=3):
|
||||
|
||||
|
||||
class eye_diagram:
|
||||
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
|
||||
def __init__(
|
||||
self,
|
||||
data,
|
||||
*,
|
||||
channel_names=None,
|
||||
horizontal_bins=256,
|
||||
vertical_bins=1000,
|
||||
n_levels=4,
|
||||
multithreaded=True,
|
||||
save_file_or_dir=None,
|
||||
):
|
||||
# data has shape [channels, 2, samples]
|
||||
# each sample has a timestamp and a value
|
||||
if data.ndim == 2:
|
||||
data = data[np.newaxis, :, :]
|
||||
self.channel_names = channel_names
|
||||
self.raw_data = data
|
||||
self.channels = data.shape[0]
|
||||
|
||||
self.y_bins = np.zeros(1)
|
||||
self.x_bins = np.zeros(1)
|
||||
self.eye_data = np.zeros(1)
|
||||
self.channel_names = channel_names
|
||||
self.n_channels = data.shape[0]
|
||||
self.n_levels = n_levels
|
||||
self.eye_stats = [{"success": False} for _ in range(self.channels)]
|
||||
self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
|
||||
self.horizontal_bins = horizontal_bins
|
||||
self.vertical_bins = vertical_bins
|
||||
self.multi_threaded = multithreaded
|
||||
self.analysed = False
|
||||
self.eye_built = False
|
||||
|
||||
def generate_eye_data(self):
|
||||
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
||||
self.y_bins = np.zeros((self.channels, self.vertical_bins))
|
||||
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
|
||||
datas = [self.raw_data[i] for i in range(self.channels)]
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.map(self.generate_eye_data_single, datas)
|
||||
for i, result in enumerate(results):
|
||||
self.eye_data[i], self.y_bins[i] = result
|
||||
self.save_file = save_file_or_dir
|
||||
|
||||
def load_data(self, file=None):
|
||||
file = self.save_file if file is None else file
|
||||
|
||||
if file is None:
|
||||
raise FileNotFoundError("No file specified.")
|
||||
|
||||
self.save_file = str(file)
|
||||
# self.file_or_dir = self.save_file
|
||||
with h5py.File(file, "r") as infile:
|
||||
self.y_bins = infile["y_bins"][:]
|
||||
self.x_bins = infile["x_bins"][:]
|
||||
self.eye_data = infile["eye_data"][:]
|
||||
self.channel_names = infile.attrs["channel_names"]
|
||||
self.n_channels = infile.attrs["n_channels"]
|
||||
self.n_levels = infile.attrs["n_levels"]
|
||||
self.eye_stats = infile.attrs["eye_stats"]
|
||||
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
|
||||
self.horizontal_bins = infile.attrs["horizontal_bins"]
|
||||
self.vertical_bins = infile.attrs["vertical_bins"]
|
||||
self.multi_threaded = infile.attrs["multithreaded"]
|
||||
self.analysed = infile.attrs["analysed"]
|
||||
self.eye_built = infile.attrs["eye_built"]
|
||||
|
||||
def save_data(self, file_or_dir=None):
|
||||
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
|
||||
if file_or_dir is None:
|
||||
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
|
||||
elif Path(file_or_dir).is_dir():
|
||||
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
|
||||
else:
|
||||
for i, data in enumerate(datas):
|
||||
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
|
||||
self.eye_built = True
|
||||
|
||||
file = Path(file_or_dir)
|
||||
|
||||
# file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.save_file = str(file)
|
||||
|
||||
with h5py.File(file, "w") as outfile:
|
||||
outfile.create_dataset("eye_data", data=self.eye_data)
|
||||
outfile.create_dataset("y_bins", data=self.y_bins)
|
||||
outfile.create_dataset("x_bins", data=self.x_bins)
|
||||
outfile.attrs["channel_names"] = self.channel_names
|
||||
outfile.attrs["n_channels"] = self.n_channels
|
||||
outfile.attrs["n_levels"] = self.n_levels
|
||||
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
|
||||
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
|
||||
outfile.attrs["horizontal_bins"] = self.horizontal_bins
|
||||
outfile.attrs["vertical_bins"] = self.vertical_bins
|
||||
outfile.attrs["multithreaded"] = self.multi_threaded
|
||||
outfile.attrs["analysed"] = self.analysed
|
||||
outfile.attrs["eye_built"] = self.eye_built
|
||||
|
||||
@staticmethod
|
||||
def convert_arrays(input_object):
|
||||
"""
|
||||
convert ndarrays in (nested) dict to lists
|
||||
"""
|
||||
|
||||
if isinstance(input_object, np.ndarray):
|
||||
return input_object.tolist()
|
||||
elif isinstance(input_object, list):
|
||||
return [eye_diagram.convert_arrays(old) for old in input_object]
|
||||
elif isinstance(input_object, tuple):
|
||||
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
|
||||
elif isinstance(input_object, dict):
|
||||
dict_out = {}
|
||||
for key, value in input_object.items():
|
||||
dict_out[key] = eye_diagram.convert_arrays(value)
|
||||
return dict_out
|
||||
return input_object
|
||||
|
||||
def generate_eye_data(
|
||||
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
|
||||
):
|
||||
# modes:
|
||||
# default: try to load eye data from file, if not found, generate and save
|
||||
# load: try to load eye data from file, if not found, generate but don't save
|
||||
# save: generate eye data and save
|
||||
update_save = True
|
||||
if mode == "load":
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
elif mode == "default":
|
||||
try:
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
except (FileNotFoundError, IsADirectoryError):
|
||||
pass
|
||||
|
||||
if not self.eye_built:
|
||||
update_save = True
|
||||
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
||||
self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
|
||||
self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
|
||||
datas = [self.raw_data[i] for i in range(self.n_channels)]
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.map(self.generate_eye_data_single, datas)
|
||||
for i, result in enumerate(results):
|
||||
self.eye_data[i], self.y_bins[i] = result
|
||||
else:
|
||||
for i, data in enumerate(datas):
|
||||
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
|
||||
self.eye_built = True
|
||||
|
||||
if mode == "save" or (mode == "default" and update_save):
|
||||
self.save_data(file_or_dir)
|
||||
|
||||
def generate_eye_data_single(self, data):
|
||||
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
|
||||
data_min = np.min(data[1, :])
|
||||
data_max = np.max(data[1, :])
|
||||
# round down/up to 1 decimal
|
||||
data_min = np.floor(data_min*10)/10
|
||||
data_max = np.ceil(data_max*10)/10
|
||||
# data_range = data_max - data_min
|
||||
# data_min -= 0.1 * data_range
|
||||
# data_max += 0.1 * data_range
|
||||
# data_min = -0.05
|
||||
# data_max += 0.05
|
||||
# data[1,:] -= np.min(data[1, :])
|
||||
# data[1,:] /= np.max(data[1, :])
|
||||
# data_min = 0
|
||||
# data_max = 1
|
||||
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
|
||||
t_vals = data[0, :] % 2
|
||||
val_vals = data[1, :]
|
||||
t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
|
||||
val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
|
||||
x_indices = np.digitize(t_vals, self.x_bins) - 1
|
||||
y_indices = np.digitize(val_vals, y_bins) - 1
|
||||
np.add.at(eye_data, (y_indices, x_indices), 1)
|
||||
return eye_data, y_bins
|
||||
|
||||
def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
|
||||
def plot(
|
||||
self,
|
||||
title="Eye Diagram",
|
||||
stats=True,
|
||||
all_stats=True,
|
||||
show=True,
|
||||
mode: Literal["default", "load", "save", "nosave"] = "default",
|
||||
# save_images = False,
|
||||
# image_dir = None,
|
||||
# cmap=None,
|
||||
):
|
||||
if stats and not self.analysed:
|
||||
self.analyse(mode=mode)
|
||||
if not self.eye_built:
|
||||
self.generate_eye_data()
|
||||
self.generate_eye_data(mode=mode)
|
||||
cmap = LinearSegmentedColormap.from_list(
|
||||
"eyemap",
|
||||
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
|
||||
[
|
||||
(0, "#FFFFFF00"),
|
||||
(0.1, "blue"),
|
||||
(0.2, "cyan"),
|
||||
(0.5, "green"),
|
||||
(0.8, "yellow"),
|
||||
(0.9, "red"),
|
||||
(1, "magenta"),
|
||||
],
|
||||
)
|
||||
if self.channels % 2 == 0:
|
||||
# cmap = cm('google:turbo_r' if cmap is None else cmap)
|
||||
# first = cmap(-1)
|
||||
# cmap = cmap.to_mpl()
|
||||
# cmap.set_under(first, alpha=0)
|
||||
if self.n_channels % 2 == 0:
|
||||
rows = 2
|
||||
cols = self.channels // 2
|
||||
cols = self.n_channels // 2
|
||||
else:
|
||||
cols = int(np.ceil(np.sqrt(self.channels)))
|
||||
rows = int(np.ceil(self.channels / cols))
|
||||
cols = int(np.ceil(np.sqrt(self.n_channels)))
|
||||
rows = int(np.ceil(self.n_channels / cols))
|
||||
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
|
||||
fig.suptitle(title)
|
||||
fig.tight_layout()
|
||||
ax = np.atleast_1d(ax).transpose().flatten()
|
||||
for i in range(self.channels):
|
||||
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
|
||||
if (i+1) % rows == 0:
|
||||
for i in range(self.n_channels):
|
||||
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
|
||||
if (i + 1) % rows == 0:
|
||||
ax[i].set_xlabel("Symbol")
|
||||
if i < rows:
|
||||
ax[i].set_ylabel("Amplitude")
|
||||
ax[i].grid()
|
||||
ax[i].set_axisbelow(True)
|
||||
ax[i].imshow(
|
||||
self.eye_data[i],
|
||||
self.eye_data[i] - 0.1,
|
||||
origin="lower",
|
||||
aspect="auto",
|
||||
cmap=cmap,
|
||||
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
|
||||
interpolation="gaussian",
|
||||
vmin=0,
|
||||
zorder=3,
|
||||
)
|
||||
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
|
||||
ymin = np.min(self.y_bins[:, 0])
|
||||
ymax = np.max(self.y_bins[:, -1])
|
||||
yspan = ymax - ymin
|
||||
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
|
||||
# if save_images:
|
||||
# image_dir = "images_out" if image_dir is None else image_dir
|
||||
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
|
||||
# image_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# # plt.imsave(
|
||||
# # image_path,
|
||||
# # self.eye_data[i] - 0.1,
|
||||
# # origin="lower",
|
||||
# # # aspect="auto",
|
||||
# # cmap=cmap,
|
||||
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
|
||||
# # # interpolation="gaussian",
|
||||
# # vmin=0,
|
||||
# # # zorder=3,
|
||||
# # )
|
||||
if stats and self.eye_stats[i]["success"]:
|
||||
# # add min_area above the plot
|
||||
# ax[i].annotate(
|
||||
@@ -159,7 +339,7 @@ class eye_diagram:
|
||||
|
||||
if all_stats:
|
||||
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
|
||||
y_ticks = (*self.eye_stats[i]["levels"],*self.eye_stats[i]["thresholds"])
|
||||
y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
|
||||
# y_ticks = np.sort(y_ticks)
|
||||
ax[i].set_yticks(y_ticks)
|
||||
# add arrows for amplitudes
|
||||
@@ -230,24 +410,24 @@ class eye_diagram:
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
|
||||
@staticmethod
|
||||
def calculate_thresholds(levels):
|
||||
ret = np.cumsum(levels, dtype=float)
|
||||
ret[2:] = ret[2:] - ret[:-2]
|
||||
return ret[1:]/2
|
||||
return ret[1:] / 2
|
||||
|
||||
def analyse_single(self, data, index):
|
||||
warnings.filterwarnings("error")
|
||||
eye_stats = {}
|
||||
eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index]
|
||||
eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
|
||||
try:
|
||||
approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
|
||||
|
||||
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
|
||||
|
||||
eye_stats["time_midpoint_calc"] = (time_bounds[0] + time_bounds[1]) / 2
|
||||
eye_stats["time_midpoint"] = 1.0
|
||||
eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
|
||||
# eye_stats["time_midpoint"] = 1.0
|
||||
|
||||
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
|
||||
data, approx_levels, time_bounds
|
||||
@@ -257,9 +437,7 @@ class eye_diagram:
|
||||
|
||||
eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
|
||||
|
||||
eye_stats["heights"] = eye_diagram.calculate_eye_heights(
|
||||
eye_stats["amplitude_clusters"]
|
||||
)
|
||||
eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
|
||||
|
||||
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
|
||||
data, eye_stats["levels"]
|
||||
@@ -291,17 +469,39 @@ class eye_diagram:
|
||||
warnings.resetwarnings()
|
||||
return eye_stats
|
||||
|
||||
def analyse(
|
||||
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
|
||||
):
|
||||
# modes:
|
||||
# default: try to load eye data from file, if not found, generate and save
|
||||
# load: try to load eye data from file, if not found, generate but don't save
|
||||
# save: generate eye data and save
|
||||
update_save = True
|
||||
if mode == "load":
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
elif mode == "default":
|
||||
try:
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
except (FileNotFoundError, IsADirectoryError):
|
||||
pass
|
||||
|
||||
def analyse(self):
|
||||
self.eye_stats = []
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
|
||||
for i, result in enumerate(results):
|
||||
self.eye_stats.append(result)
|
||||
else:
|
||||
for i in range(self.channels):
|
||||
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
|
||||
if not self.analysed:
|
||||
update_save = True
|
||||
self.eye_stats = []
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
|
||||
for i, result in enumerate(results):
|
||||
self.eye_stats.append(result)
|
||||
else:
|
||||
for i in range(self.n_channels):
|
||||
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
|
||||
self.analysed = True
|
||||
|
||||
if mode == "save" or (mode == "default" and update_save):
|
||||
self.save_data(file_or_dir)
|
||||
|
||||
@staticmethod
|
||||
def approximate_levels(data, levels):
|
||||
@@ -443,7 +643,7 @@ class eye_diagram:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
length = int(2**14)
|
||||
length = int(2**16)
|
||||
# data = generate_sample_data(length, noise=1)
|
||||
# data1 = generate_sample_data(length, noise=0.01)
|
||||
# data2 = generate_sample_data(length, noise=0.01, skew=1.2)
|
||||
@@ -451,13 +651,13 @@ if __name__ == "__main__":
|
||||
|
||||
# data = np.stack([data, data1, data2, data3])
|
||||
|
||||
data = generate_sample_data(length, noise=0.005)
|
||||
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
|
||||
eye.analyse()
|
||||
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
|
||||
for i, channel in enumerate(eye.eye_stats):
|
||||
print(f"Channel {i}")
|
||||
print_data = {attr: channel[attr] for attr in attrs}
|
||||
print(print_data)
|
||||
data = generate_sample_data(length, noise=0.0000)
|
||||
eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
|
||||
eye.plot(mode="nosave", stats=False)
|
||||
# attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
|
||||
# for i, channel in enumerate(eye.eye_stats):
|
||||
# print(f"Channel {i}")
|
||||
# print_data = {attr: channel[attr] for attr in attrs}
|
||||
# print(print_data)
|
||||
|
||||
eye.plot()
|
||||
# eye.plot()
|
||||
|
||||
122
src/single-core-regen/util/mpl.py
Normal file
122
src/single-core-regen/util/mpl.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
|
||||
# This software is licensed according to the "BSD 2-clause" license.
|
||||
|
||||
# modified by Joseph Hopfmüller in 2025,
|
||||
# for integration into optical regeneration analysis scripts
|
||||
|
||||
from pathlib import Path
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import matplotlib.colors as colors
|
||||
import numpy as _np
|
||||
from .core import grid_count as _grid_count
|
||||
import matplotlib.pyplot as _plt
|
||||
import numpy as np
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
|
||||
# from ._common import _common_doc
|
||||
|
||||
|
||||
__all__ = ["eyediagram"] # , 'eyediagram_lines']
|
||||
|
||||
|
||||
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
|
||||
# """
|
||||
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
|
||||
# function.
|
||||
# <common>
|
||||
|
||||
# """
|
||||
# start = offset
|
||||
# while start < len(y):
|
||||
# end = start + window_size
|
||||
# if end > len(y):
|
||||
# end = len(y)
|
||||
# yy = y[start:end+1]
|
||||
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
|
||||
# start = end
|
||||
|
||||
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
|
||||
# _common_doc)
|
||||
|
||||
|
||||
eyemap = LinearSegmentedColormap.from_list(
|
||||
"eyemap",
|
||||
[
|
||||
(0, "#0000FF00"),
|
||||
(0.1, "blue"),
|
||||
(0.2, "cyan"),
|
||||
(0.5, "green"),
|
||||
(0.8, "yellow"),
|
||||
(0.9, "red"),
|
||||
(1, "magenta"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def eyediagram(
|
||||
y,
|
||||
window_size,
|
||||
offset=0,
|
||||
colorbar=False,
|
||||
show=False,
|
||||
save_im=False,
|
||||
overwrite=False,
|
||||
blur: int | bool = True,
|
||||
save_path="out.png",
|
||||
bounds=None,
|
||||
**imshowkwargs,
|
||||
):
|
||||
"""
|
||||
Plot an eye diagram using matplotlib by creating an image and calling
|
||||
the `imshow` function.
|
||||
<common>
|
||||
"""
|
||||
if bounds is None:
|
||||
ymax = y.max()
|
||||
ymin = y.min()
|
||||
yamp = ymax - ymin
|
||||
ymin = ymin - 0.05 * yamp
|
||||
ymax = ymax + 0.05 * yamp
|
||||
ymin = np.floor(ymin * 10) / 10
|
||||
ymax = np.ceil(ymax * 10) / 10
|
||||
bounds = (ymin, ymax)
|
||||
counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
|
||||
counts = counts.astype(_np.float32)
|
||||
origin = imshowkwargs.pop("origin", "lower")
|
||||
cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
|
||||
vmin = imshowkwargs.pop("vmin", 1)
|
||||
vmax = imshowkwargs.pop("vmax", None)
|
||||
cmap.set_under("white", alpha=0)
|
||||
|
||||
if show:
|
||||
_plt.imshow(
|
||||
counts.T[::-1, :],
|
||||
extent=[0, 2, *bounds],
|
||||
origin=origin,
|
||||
cmap=cmap,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
**imshowkwargs,
|
||||
)
|
||||
_plt.grid()
|
||||
if colorbar:
|
||||
_plt.colorbar()
|
||||
|
||||
if Path(save_path).is_file() and not overwrite:
|
||||
save_im = False
|
||||
if save_im:
|
||||
from PIL import Image
|
||||
arr = counts.T[::-1, :]
|
||||
if origin == "lower":
|
||||
arr = arr[::-1]
|
||||
arr = (arr-arr.min())/(arr.max()-arr.min())
|
||||
image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
|
||||
image.save(save_path)
|
||||
# print("-")
|
||||
|
||||
if show:
|
||||
_plt.show()
|
||||
|
||||
|
||||
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)
|
||||
Reference in New Issue
Block a user